File size: 28,960 Bytes
d5001fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
import json
import os
import re
import tarfile
import zipfile
from pathlib import Path
from shutil import copyfile, rmtree
from typing import Dict, List, Tuple

import fsspec
import requests
from tqdm import tqdm

from TTS.config import load_config, read_json_with_comments
from TTS.utils.generic_utils import get_user_data_dir

LICENSE_URLS = {
    "cc by-nc-nd 4.0": "https://creativecommons.org/licenses/by-nc-nd/4.0/",
    "mpl": "https://www.mozilla.org/en-US/MPL/2.0/",
    "mpl2": "https://www.mozilla.org/en-US/MPL/2.0/",
    "mpl 2.0": "https://www.mozilla.org/en-US/MPL/2.0/",
    "mit": "https://choosealicense.com/licenses/mit/",
    "apache 2.0": "https://choosealicense.com/licenses/apache-2.0/",
    "apache2": "https://choosealicense.com/licenses/apache-2.0/",
    "cc-by-sa 4.0": "https://creativecommons.org/licenses/by-sa/4.0/",
    "cpml": "https://coqui.ai/cpml.txt",
}


class ModelManager(object):
    """Manage TTS models defined in .models.json.

    It provides an interface to list and download the models defined in `.models.json`.

    Models are downloaded under the directory returned by `get_user_data_dir("tts")`, or under
    `<output_prefix>/tts` when `output_prefix` is given.

    Args:
        models_file (str): path to .models.json file. Defaults to None, which reads the
            `.models.json` located one directory above this module.
        output_prefix (str): prefix to `tts` to download models. Defaults to None
        progress_bar (bool): print a progress bar when downloading a file. Defaults to False.
        verbose (bool): print info. Defaults to True.
    """

    # tqdm bar of the most recent download; stored on the class because the download helpers
    # are static methods.
    tqdm_progress = None

    def __init__(self, models_file=None, output_prefix=None, progress_bar=False, verbose=True):
        super().__init__()
        self.progress_bar = progress_bar
        self.verbose = verbose
        if output_prefix is None:
            self.output_prefix = get_user_data_dir("tts")
        else:
            self.output_prefix = os.path.join(output_prefix, "tts")
        self.models_dict = None
        if models_file is None:
            # try the default location
            models_file = Path(__file__).parent / "../.models.json"
        self.read_models_file(models_file)

    def read_models_file(self, file_path):
        """Read .models.json as a dict and store it in ``self.models_dict``.

        Args:
            file_path (str): path to .models.json.
        """
        self.models_dict = read_json_with_comments(file_path)

    def _list_models(self, model_type, model_count=0):
        """Return every model name of ``model_type`` formatted as ``type/language/dataset/model``.

        When ``self.verbose`` is set, each name is also printed with a running index starting at
        ``model_count`` and tagged when its download folder already exists.

        Args:
            model_type (str): top-level key of ``self.models_dict`` (e.g. ``tts_models``).
            model_count (int): index printed next to the first model. Defaults to 0.

        Returns:
            List[str]: the model names in ``.models.json`` order.
        """
        if self.verbose:
            print("\n Name format: type/language/dataset/model")
        model_list = []
        for lang in self.models_dict[model_type]:
            for dataset in self.models_dict[model_type][lang]:
                for model in self.models_dict[model_type][lang][dataset]:
                    model_name = f"{model_type}/{lang}/{dataset}/{model}"
                    if self.verbose:
                        output_path = os.path.join(self.output_prefix, f"{model_type}--{lang}--{dataset}--{model}")
                        if os.path.exists(output_path):
                            print(f" {model_count}: {model_name} [already downloaded]")
                        else:
                            print(f" {model_count}: {model_name}")
                    model_list.append(model_name)
                    model_count += 1
        return model_list

    def _list_for_model_type(self, model_type):
        """Print (when verbose) and return the models of one type, numbered from 1."""
        return self._list_models(model_type, model_count=1)

    def list_models(self):
        """Print (when verbose) and return the names of the models of every type.

        Numbering restarts at 1 for each model type, which matches the per-type index expected by
        ``model_info_by_idx``.
        """
        models_name_list = []
        for model_type in self.models_dict:
            models_name_list.extend(self._list_models(model_type, model_count=1))
        return models_name_list

    def _print_model_info(self, model_type, lang, dataset, model):
        """Print type, language, dataset, name, description and default vocoder of one model entry."""
        model_entry = self.models_dict[model_type][lang][dataset][model]
        print(f"> model type : {model_type}")
        print(f"> language supported : {lang}")
        print(f"> dataset used : {dataset}")
        print(f"> model name : {model}")
        if "description" in model_entry:
            print(f"> description : {model_entry['description']}")
        else:
            print("> description : coming soon")
        if "default_vocoder" in model_entry:
            print(f"> default_vocoder : {model_entry['default_vocoder']}")

    def model_info_by_idx(self, model_query):
        """Print the description of the model from .models.json file using model_idx

        Args:
            model_query (str): ``<model_type>/<model_idx>`` where ``model_idx`` is the 1-based
                index of the model within ``model_type``, in ``.models.json`` order.
        """
        model_type, model_query_idx = model_query.split("/")
        try:
            model_query_idx = int(model_query_idx)
        except ValueError:
            print("> model_query_idx should be an integer!")
            return
        if model_query_idx <= 0:
            print("> model_query_idx should be a positive integer!")
            return
        if model_type not in self.models_dict:
            print(f"> model_type {model_type} does not exist in the list.")
            return
        # (lang, dataset, model) triples in the same order the listing methods number them
        model_entries = []
        for lang in self.models_dict[model_type]:
            for dataset in self.models_dict[model_type][lang]:
                for model in self.models_dict[model_type][lang][dataset]:
                    model_entries.append((lang, dataset, model))
        if model_query_idx > len(model_entries):
            print(f"model query idx exceeds the number of available models [{len(model_entries)}] ")
            return
        lang, dataset, model = model_entries[model_query_idx - 1]
        self._print_model_info(model_type, lang, dataset, model)

    def model_info_by_full_name(self, model_query_name):
        """Print the description of the model from .models.json file using model_full_name

        Args:
            model_query_name (str): Format is <model_type>/<language>/<dataset>/<model_name>
        """
        model_type, lang, dataset, model = model_query_name.split("/")
        if model_type not in self.models_dict:
            print(f"> model_type {model_type} does not exist in the list.")
            return
        if lang not in self.models_dict[model_type]:
            print(f"> lang {lang} does not exist for {model_type}.")
            return
        if dataset not in self.models_dict[model_type][lang]:
            print(f"> dataset {dataset} does not exist for {model_type}/{lang}.")
            return
        if model not in self.models_dict[model_type][lang][dataset]:
            print(f"> model {model} does not exist for {model_type}/{lang}/{dataset}.")
            return
        self._print_model_info(model_type, lang, dataset, model)

    def list_tts_models(self):
        """Print all `TTS` models and return a list of model names

        Format is `type/language/dataset/model`
        """
        return self._list_for_model_type("tts_models")

    def list_vocoder_models(self):
        """Print all the `vocoder` models and return a list of model names

        Format is `type/language/dataset/model`
        """
        return self._list_for_model_type("vocoder_models")

    def list_vc_models(self):
        """Print all the voice conversion models and return a list of model names

        Format is `type/language/dataset/model`
        """
        return self._list_for_model_type("voice_conversion_models")

    def list_langs(self):
        """Print all the available languages"""
        print(" Name format: type/language")
        for model_type in self.models_dict:
            for lang in self.models_dict[model_type]:
                print(f" >: {model_type}/{lang} ")

    def list_datasets(self):
        """Print all the datasets"""
        print(" Name format: type/language/dataset")
        for model_type in self.models_dict:
            for lang in self.models_dict[model_type]:
                for dataset in self.models_dict[model_type][lang]:
                    print(f" >: {model_type}/{lang}/{dataset}")

    @staticmethod
    def print_model_license(model_item: Dict):
        """Print the license of a model and a link to its text.

        Args:
            model_item (dict): model item in the models.json
        """
        license_name = (model_item.get("license") or "").strip()
        if not license_name:
            print(" > Model's license - No license information available")
            return
        print(f" > Model's license - {model_item['license']}")
        license_url = LICENSE_URLS.get(license_name.lower())
        if license_url is not None:
            print(f" > Check {license_url} for more info.")
        else:
            print(" > Check https://opensource.org/licenses for more info.")

    def _download_from_url_field(self, model_item: Dict, url_key: str, output_path: str):
        """Download ``model_item[url_key]`` into ``output_path``.

        A list of URLs is fetched file by file; a single URL is treated as a zip archive.
        """
        urls = model_item[url_key]
        if isinstance(urls, list):
            self._download_model_files(urls, output_path, self.progress_bar)
        else:
            self._download_zip_file(urls, output_path, self.progress_bar)

    def _download_github_model(self, model_item: Dict, output_path: str):
        """Download the files referenced by ``model_item["github_rls_url"]``."""
        self._download_from_url_field(model_item, "github_rls_url", output_path)

    def _download_hf_model(self, model_item: Dict, output_path: str):
        """Download the files referenced by ``model_item["hf_url"]``."""
        self._download_from_url_field(model_item, "hf_url", output_path)

    def download_fairseq_model(self, model_name, output_path):
        """Download and unpack the fairseq archive for the language in ``model_name``.

        Args:
            model_name (str): ``tts_models/<lang>/fairseq/<model>``; only ``<lang>`` is used.
            output_path (str): folder the archive contents end up in.
        """
        URI_PREFIX = "https://coqui.gateway.scarf.sh/fairseq/"
        _, lang, _, _ = model_name.split("/")
        # Plain string formatting: os.path.join would insert "\" into the URL on Windows.
        model_download_uri = f"{URI_PREFIX}{lang}.tar.gz"
        self._download_tar_file(model_download_uri, output_path, self.progress_bar)

    @staticmethod
    def set_model_url(model_item: Dict):
        """Set ``model_item["model_url"]`` from its download source (None when none is known)."""
        model_item["model_url"] = None
        if "github_rls_url" in model_item:
            model_item["model_url"] = model_item["github_rls_url"]
        elif "hf_url" in model_item:
            model_item["model_url"] = model_item["hf_url"]
        elif "fairseq" in model_item["model_name"]:
            model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/"
        elif "xtts" in model_item["model_name"]:
            model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/"
        return model_item

    def _set_model_item(self, model_name):
        """Resolve ``model_name`` to its model item and download folder name.

        Accepted names:
            - fairseq models (``tts_models/<lang>/fairseq/<model>``), described by a built-in item;
            - XTTS shorthands (names containing ``xtts`` that are not 4-part names, e.g.
              ``xtts_v2.0.2``), pointing at the XTTS-v2 files of that ``vX.Y.Z`` version, or of
              ``main`` when the name carries no version;
            - any other ``type/language/dataset/model`` name, looked up in ``self.models_dict``
              (the entry is updated in place with ``model_type`` and ``model_url``).

        Returns:
            Tuple: ``(model_item, model_full_name, model, md5hash)`` where ``model_full_name`` is
            ``type--language--dataset--model`` and ``md5hash`` is ``model_item["model_hash"]`` or None.
        """
        if "fairseq" in model_name:
            # dataset and model are needed below to build the download folder name
            _, lang, dataset, model = model_name.split("/")
            model_type = "tts_models"
            model_item = {
                "model_type": "tts_models",
                "license": "CC BY-NC 4.0",
                "default_vocoder": None,
                "author": "fairseq",
                "description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
                "model_name": model_name,
            }
        elif "xtts" in model_name and len(model_name.split("/")) != 4:
            # loading xtts models with only model name (e.g. xtts_v2.0.2)
            # check model name has the version number with regex
            version_regex = r"v\d+\.\d+\.\d+"
            if re.search(version_regex, model_name):
                model_version = model_name.split("_")[-1]
            else:
                model_version = "main"
            model_type = "tts_models"
            lang = "multilingual"
            dataset = "multi-dataset"
            model = model_name
            model_item = {
                "default_vocoder": None,
                "license": "CPML",
                "contact": "info@coqui.ai",
                "tos_required": True,
                "hf_url": [
                    f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth",
                    f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json",
                    f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json",
                    f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5",
                    f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/speakers_xtts.pth",
                ],
            }
        else:
            # get model from models.json
            model_type, lang, dataset, model = model_name.split("/")
            model_item = self.models_dict[model_type][lang][dataset][model]
            model_item["model_type"] = model_type

        model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
        md5hash = model_item.get("model_hash")
        model_item = self.set_model_url(model_item)
        return model_item, model_full_name, model, md5hash

    @staticmethod
    def ask_tos(model_full_path):
        """Ask the user to agree to the terms of service.

        On "y" (case- and whitespace-insensitive) the agreement is recorded in
        ``tos_agreed.txt`` inside ``model_full_path`` and True is returned; otherwise False.
        """
        tos_path = os.path.join(model_full_path, "tos_agreed.txt")
        print(" > You must confirm the following:")
        print(' | > "I have purchased a commercial license from Coqui: licensing@coqui.ai"')
        print(' | > "Otherwise, I agree to the terms of the non-commercial CPML: https://coqui.ai/cpml" - [y/n]')
        answer = input(" | | > ")
        if answer.strip().lower() == "y":
            with open(tos_path, "w", encoding="utf-8") as f:
                f.write("I have read, understood and agreed to the Terms and Conditions.")
            return True
        return False

    @staticmethod
    def tos_agreed(model_item, model_full_path):
        """Check if the user has agreed to the terms of service.

        Returns True when the model does not require them, when ``tos_agreed.txt`` exists in
        ``model_full_path``, or when the ``COQUI_TOS_AGREED`` environment variable is "1".
        """
        if "tos_required" in model_item and model_item["tos_required"]:
            tos_path = os.path.join(model_full_path, "tos_agreed.txt")
            if os.path.exists(tos_path) or os.environ.get("COQUI_TOS_AGREED") == "1":
                return True
            return False
        return True

    def create_dir_and_download_model(self, model_name, model_item, output_path):
        """Create ``output_path``, check the terms of service and download the model into it.

        Raises:
            Exception: when the model requires the terms of service and the user declines them.
            requests.RequestException, zipfile.BadZipFile, tarfile.TarError: when downloading or
                unpacking fails; ``output_path`` is removed before re-raising.
        """
        os.makedirs(output_path, exist_ok=True)
        # handle TOS
        if not self.tos_agreed(model_item, output_path):
            if not self.ask_tos(output_path):
                # Only remove an empty folder: os.rmdir on a folder that holds an earlier download
                # raises OSError, which would hide the terms-of-service error below.
                if not os.listdir(output_path):
                    os.rmdir(output_path)
                raise Exception(" [!] You must agree to the terms of service to use this model.")
        print(f" > Downloading model to {output_path}")
        try:
            if "fairseq" in model_name:
                self.download_fairseq_model(model_name, output_path)
            elif "github_rls_url" in model_item:
                self._download_github_model(model_item, output_path)
            elif "hf_url" in model_item:
                self._download_hf_model(model_item, output_path)
        except (requests.RequestException, zipfile.BadZipFile, tarfile.TarError):
            # Never leave a half-populated folder: download_model treats an existing folder as
            # an already downloaded model.
            print(f" > Failed to download the model file to {output_path}")
            rmtree(output_path)
            raise
        self.print_model_license(model_item=model_item)

    def check_if_configs_are_equal(self, model_name, model_item, output_path):
        """Re-download the model when its local config.json differs from the remote one.

        The remote config is the first URL in ``model_item["hf_url"]`` containing ``config.json``;
        nothing is done when there is no such URL.
        """
        remote_url = next((url for url in model_item.get("hf_url", []) if "config.json" in url), None)
        if remote_url is None:
            return
        with fsspec.open(self._find_files(output_path)[1], "r", encoding="utf-8") as f:
            config_local = json.load(f)
        with fsspec.open(remote_url, "r", encoding="utf-8") as f:
            config_remote = json.load(f)

        if config_local != config_remote:
            print(f" > {model_name} is already downloaded however it has been changed. Redownloading it...")
            self.create_dir_and_download_model(model_name, model_item, output_path)

    @staticmethod
    def _read_local_md5(output_path):
        """Return the stripped content of ``hash.md5`` in ``output_path``, or None if it is missing."""
        md5sum_file = os.path.join(output_path, "hash.md5")
        if not os.path.isfile(md5sum_file):
            return None
        with open(md5sum_file, mode="r", encoding="utf-8") as f:
            return f.read().strip()

    def download_model(self, model_name):
        """Download model files given the full model name.

        Model name is in the format
            'type/language/dataset/model'
            e.g. 'tts_models/en/ljspeech/tacotron'
        fairseq names and XTTS shorthands (e.g. 'xtts_v2.0.2') are accepted as well.

        An existing download folder is reused unless the model defines a ``model_hash`` that does
        not match the local ``hash.md5`` (or that file is missing). XTTS models additionally have
        their local config.json compared with the remote one.

        Every model must have the following files:
            - *.pth : pytorch model checkpoint file.
            - config.json : model config file.
            - scale_stats.npy (if exist): scale values for preprocessing.

        Args:
            model_name (str): model name as explained above.

        Returns:
            Tuple: model file path (the model folder for tortoise/bark/fairseq/xtts models),
            config.json path (None for those models) and the model item.
        """
        model_item, model_full_name, model, md5sum = self._set_model_item(model_name)
        # set the model specific output path
        output_path = os.path.join(self.output_prefix, model_full_name)
        if not os.path.exists(output_path):
            self.create_dir_and_download_model(model_name, model_item, output_path)
        else:
            # Read the local hash first so the file is closed before a re-download overwrites it.
            if md5sum is not None and self._read_local_md5(output_path) != str(md5sum).strip():
                print(f" > {model_name} has been updated, clearing model cache...")
                self.create_dir_and_download_model(model_name, model_item, output_path)
            else:
                print(f" > {model_name} is already downloaded.")
            # if the configs are different, redownload it
            # ToDo: we need a better way to handle it
            if "xtts" in model_name:
                try:
                    self.check_if_configs_are_equal(model_name, model_item, output_path)
                except Exception as e:  # pylint: disable=broad-except
                    # best effort (e.g. offline): keep using the local copy
                    print(f" > Could not compare the local config of {model_name} with the remote one: {e}")

        # find downloaded files
        output_model_path = output_path
        output_config_path = None
        # TODO: these models do not use the model.pth/config.json layout; detect that from the
        # model item instead of hard-coding names.
        if model not in ["tortoise-v2", "bark"] and "fairseq" not in model_name and "xtts" not in model_name:
            output_model_path, output_config_path = self._find_files(output_path)
        # update paths in the config.json
        self._update_paths(output_path, output_config_path)
        return output_model_path, output_config_path, model_item

    @staticmethod
    def _find_files(output_path: str) -> Tuple[str, str]:
        """Find the model and config files in the output path

        Args:
            output_path (str): path to the model files

        Returns:
            Tuple[str, str]: path to the model file and config file

        Raises:
            ValueError: when either file is missing.
        """
        model_file = None
        config_file = None
        for file_name in os.listdir(output_path):
            if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
                model_file = os.path.join(output_path, file_name)
            elif file_name == "config.json":
                config_file = os.path.join(output_path, file_name)
        if model_file is None:
            raise ValueError(" [!] Model file not found in the output path")
        if config_file is None:
            raise ValueError(" [!] Config file not found in the output path")
        return model_file, config_file

    @staticmethod
    def _find_speaker_encoder(output_path: str) -> str:
        """Find the speaker encoder file in the output path

        Args:
            output_path (str): path to the model files

        Returns:
            str: path to the speaker encoder file, or None when there is none
        """
        speaker_encoder_file = None
        for file_name in os.listdir(output_path):
            if file_name in ["model_se.pth", "model_se.pth.tar"]:
                speaker_encoder_file = os.path.join(output_path, file_name)
        return speaker_encoder_file

    def _update_paths(self, output_path: str, config_path: str) -> None:
        """Update paths for certain files in config.json after download.

        Args:
            output_path (str): local path the model is downloaded to.
            config_path (str): local config.json path; None means there is nothing to update.
        """
        if config_path is None:
            # models without a standard config.json (see download_model) have nothing to patch
            return
        output_stats_path = os.path.join(output_path, "scale_stats.npy")
        output_d_vector_file_path = os.path.join(output_path, "speakers.json")
        output_d_vector_file_pth_path = os.path.join(output_path, "speakers.pth")
        output_speaker_ids_file_path = os.path.join(output_path, "speaker_ids.json")
        output_speaker_ids_file_pth_path = os.path.join(output_path, "speaker_ids.pth")
        speaker_encoder_config_path = os.path.join(output_path, "config_se.json")
        speaker_encoder_model_path = self._find_speaker_encoder(output_path)

        # update the scale_path.npy file path in the model config.json
        self._update_path("audio.stats_path", output_stats_path, config_path)

        # update the speakers.json file path in the model config.json to the current path
        self._update_path("d_vector_file", output_d_vector_file_path, config_path)
        self._update_path("d_vector_file", output_d_vector_file_pth_path, config_path)
        self._update_path("model_args.d_vector_file", output_d_vector_file_path, config_path)
        self._update_path("model_args.d_vector_file", output_d_vector_file_pth_path, config_path)

        # update the speaker_ids.json file path in the model config.json to the current path
        self._update_path("speakers_file", output_speaker_ids_file_path, config_path)
        self._update_path("speakers_file", output_speaker_ids_file_pth_path, config_path)
        self._update_path("model_args.speakers_file", output_speaker_ids_file_path, config_path)
        self._update_path("model_args.speakers_file", output_speaker_ids_file_pth_path, config_path)

        # update the speaker_encoder file path in the model config.json to the current path
        self._update_path("speaker_encoder_model_path", speaker_encoder_model_path, config_path)
        self._update_path("model_args.speaker_encoder_model_path", speaker_encoder_model_path, config_path)
        self._update_path("speaker_encoder_config_path", speaker_encoder_config_path, config_path)
        self._update_path("model_args.speaker_encoder_config_path", speaker_encoder_config_path, config_path)

    @staticmethod
    def _update_path(field_name, new_path, config_path):
        """Update the path in the model config.json for the current environment after download.

        ``field_name`` may be dotted (e.g. ``model_args.d_vector_file``) to reach a nested field.
        Nothing happens when ``new_path`` is empty or does not exist, when ``config_path`` is
        None, or when any part of the field is missing. List-valued fields become ``[new_path]``.
        """
        if not new_path or config_path is None or not os.path.exists(new_path):
            return
        config = load_config(config_path)
        *parent_fields, leaf_field = field_name.split(".")
        sub_conf = config
        for fd in parent_fields:
            if fd not in sub_conf:
                return
            sub_conf = sub_conf[fd]
        # the leaf is checked too; indexing a missing leaf used to raise
        if leaf_field not in sub_conf:
            return
        if isinstance(sub_conf[leaf_field], list):
            sub_conf[leaf_field] = [new_path]
        else:
            sub_conf[leaf_field] = new_path
        config.save_json(config_path)

    @staticmethod
    def _stream_to_file(response, dst_path, progress_bar, block_size=1024):
        """Write a streamed ``requests`` response to ``dst_path`` in ``block_size``-byte chunks.

        With ``progress_bar`` set, a tqdm bar sized from the ``content-length`` header is stored in
        ``ModelManager.tqdm_progress``, advanced per chunk and closed when writing ends.
        """
        total_size_in_bytes = int(response.headers.get("content-length", 0))
        if progress_bar:
            ModelManager.tqdm_progress = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
        try:
            with open(dst_path, "wb") as file:
                for data in response.iter_content(block_size):
                    if progress_bar:
                        ModelManager.tqdm_progress.update(len(data))
                    file.write(data)
        finally:
            if progress_bar:
                ModelManager.tqdm_progress.close()

    @staticmethod
    def _download_zip_file(file_url, output_folder, progress_bar):
        """Download a zip archive and unpack its files directly into ``output_folder``.

        Files inside folders of the archive are copied up to ``output_folder``; the folders listed
        in the archive are then removed.

        Raises:
            requests.HTTPError: when the server answers with an error status.
            zipfile.BadZipFile: when the downloaded file is not a valid zip archive.
        """
        r = requests.get(file_url, stream=True)
        # Fail on HTTP errors instead of saving the error page and reporting a bad zip file.
        r.raise_for_status()
        temp_zip_name = os.path.join(output_folder, file_url.split("/")[-1])
        ModelManager._stream_to_file(r, temp_zip_name, progress_bar)
        try:
            with zipfile.ZipFile(temp_zip_name) as z:
                z.extractall(output_folder)
                member_names = z.namelist()
        except zipfile.BadZipFile:
            print(f" > Error: Bad zip file - {file_url}")
            raise
        os.remove(temp_zip_name)  # delete zip after extract
        # move the files to the outer path
        for file_path in member_names:
            src_path = os.path.join(output_folder, file_path)
            if os.path.isfile(src_path):
                dst_path = os.path.join(output_folder, os.path.basename(file_path))
                if src_path != dst_path:
                    copyfile(src_path, dst_path)
        # remove redundant (hidden or not) folders
        for file_path in member_names:
            if os.path.isdir(os.path.join(output_folder, file_path)):
                rmtree(os.path.join(output_folder, file_path))

    @staticmethod
    def _download_tar_file(file_url, output_folder, progress_bar):
        """Download a tar archive and move the content of its first entry into ``output_folder``.

        Raises:
            requests.HTTPError: when the server answers with an error status.
            tarfile.ReadError: when the downloaded file is not a readable tar archive.
        """
        r = requests.get(file_url, stream=True)
        # Fail on HTTP errors instead of saving the error page and reporting a bad tar file.
        r.raise_for_status()
        temp_tar_name = os.path.join(output_folder, file_url.split("/")[-1])
        ModelManager._stream_to_file(r, temp_tar_name, progress_bar)
        try:
            with tarfile.open(temp_tar_name) as t:
                # The archive comes from the network: when the running Python provides extraction
                # filters, reject members that would escape output_folder (absolute paths, "..",
                # links pointing outside).
                if hasattr(tarfile, "data_filter"):
                    t.extractall(output_folder, filter="data")
                else:
                    t.extractall(output_folder)
                tar_names = t.getnames()
        except tarfile.ReadError:
            print(f" > Error: Bad tar file - {file_url}")
            raise
        os.remove(temp_tar_name)  # delete tar after extract
        # move the files to the outer path
        for file_path in os.listdir(os.path.join(output_folder, tar_names[0])):
            src_path = os.path.join(output_folder, tar_names[0], file_path)
            dst_path = os.path.join(output_folder, os.path.basename(file_path))
            if src_path != dst_path:
                copyfile(src_path, dst_path)
        # remove the extracted folder
        rmtree(os.path.join(output_folder, tar_names[0]))

    @staticmethod
    def _download_model_files(file_urls, output_folder, progress_bar):
        """Download each URL of ``file_urls`` into ``output_folder``, named after its last path segment.

        Raises:
            requests.HTTPError: when the server answers any request with an error status.
        """
        for file_url in file_urls:
            r = requests.get(file_url, stream=True)
            # Without this, an HTTP error page would be silently saved as a model file.
            r.raise_for_status()
            base_filename = file_url.split("/")[-1]
            ModelManager._stream_to_file(r, os.path.join(output_folder, base_filename), progress_bar)

    @staticmethod
    def _check_dict_key(my_dict, key):
        """Return True when ``key`` is in ``my_dict`` with a value that is neither None nor empty.

        Values without a length (numbers, booleans, ...) count as set.
        """
        if key not in my_dict or my_dict[key] is None:
            return False
        try:
            return len(my_dict[key]) > 0
        except TypeError:
            return True