|
|
""" |
|
|
Model Downloader for SYSPIN TTS Models |
|
|
Downloads models from Hugging Face Hub |
|
|
""" |
|
|
|
|
|
import os |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Optional, List |
|
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR |
|
|
|
|
|
# Module-level logger, named after this module so applications can filter
# or redirect downloader messages independently of other loggers.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class ModelDownloader:
    """Downloads and manages SYSPIN TTS models from Hugging Face.

    Every voice is stored in its own sub-directory of ``models_dir``; the
    sub-directory is named after the voice's key in ``LANGUAGE_CONFIGS``.
    """

    def __init__(self, models_dir: str = MODELS_DIR):
        """Create the downloader and make sure the models directory exists.

        Args:
            models_dir: Root directory under which each voice gets its own
                sub-directory. Created (with parents) if missing.
        """
        self.models_dir = Path(models_dir)
        self.models_dir.mkdir(parents=True, exist_ok=True)

    def download_model(self, voice_key: str, force: bool = False) -> Path:
        """Download a specific voice model.

        The download is skipped when both the model file and the characters
        file named in the voice's config already exist, unless ``force`` is
        set.

        Args:
            voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male', 'bn_female')
            force: Re-download even if the files already exist locally.

        Returns:
            Path to downloaded model directory

        Raises:
            ValueError: If ``voice_key`` is not a key of LANGUAGE_CONFIGS.
            Exception: Whatever ``snapshot_download`` raises; the error is
                logged and re-raised unchanged.
        """
        if voice_key not in LANGUAGE_CONFIGS:
            raise ValueError(
                f"Unknown voice: {voice_key}. Available: {list(LANGUAGE_CONFIGS.keys())}"
            )

        config = LANGUAGE_CONFIGS[voice_key]
        model_dir = self.models_dir / voice_key

        # These two files are what the "already downloaded" check relies on.
        model_path = model_dir / config.model_filename
        chars_path = model_dir / config.chars_filename

        if not force and model_path.exists() and chars_path.exists():
            logger.info("Model %s already downloaded at %s", voice_key, model_dir)
            return model_dir

        logger.info("Downloading %s from %s...", voice_key, config.hf_model_id)
        model_dir.mkdir(parents=True, exist_ok=True)

        try:
            snapshot_download(
                repo_id=config.hf_model_id,
                local_dir=str(model_dir),
                # NOTE(review): recent huggingface_hub releases deprecate this
                # argument; kept for older versions - confirm against the
                # pinned dependency before removing.
                local_dir_use_symlinks=False,
                allow_patterns=["*.pt", "*.pth", "*.txt", "*.py", "*.json"],
                # Without this, files that are already present and up to date
                # are skipped, so ``force=True`` would not re-download them.
                force_download=force,
            )
        except Exception as e:
            logger.error("Failed to download %s: %s", voice_key, e)
            raise

        logger.info("Successfully downloaded %s to %s", voice_key, model_dir)

        # Best-effort sanity check: if the configured files are not where the
        # cache check expects them, later calls will never see this voice as
        # downloaded. Warn instead of raising so existing callers keep working.
        missing = [str(p) for p in (model_path, chars_path) if not p.exists()]
        if missing:
            logger.warning(
                "Download of %s finished but expected files are missing: %s",
                voice_key,
                ", ".join(missing),
            )

        return model_dir

    def _download_many(self, voice_keys, force: bool = False) -> List[Path]:
        """Download each voice in ``voice_keys``, skipping the ones that fail.

        Args:
            voice_keys: Iterable of voice keys to download, in order.
            force: Forwarded to :meth:`download_model`.

        Returns:
            Directories of the voices that downloaded without error.
        """
        downloaded = []
        for voice_key in voice_keys:
            try:
                downloaded.append(self.download_model(voice_key, force=force))
            except Exception as e:
                # Deliberately best-effort: one failing voice must not abort
                # the remaining downloads.
                logger.warning("Failed to download %s: %s", voice_key, e)
        return downloaded

    def download_all_models(self, force: bool = False) -> List[Path]:
        """Download all available models, showing a progress bar.

        Args:
            force: Re-download voices even if they already exist locally.

        Returns:
            Directories of the voices that downloaded without error; failed
            voices are logged and omitted.
        """
        progress = tqdm(LANGUAGE_CONFIGS.keys(), desc="Downloading models")
        return self._download_many(progress, force=force)

    def download_language(self, lang_code: str, force: bool = False) -> List[Path]:
        """Download all voices for a specific language.

        Args:
            lang_code: Language code compared against each config's ``code``.
            force: Re-download voices even if they already exist locally.

        Returns:
            Directories of the voices that downloaded without error. Empty
            when no configured voice has the given language code.
        """
        voice_keys = [
            voice_key
            for voice_key, config in LANGUAGE_CONFIGS.items()
            if config.code == lang_code
        ]
        if not voice_keys:
            # Make a mistyped language code visible instead of silently
            # returning an empty list.
            logger.warning("No voices configured for language code %r", lang_code)
        return self._download_many(voice_keys, force=force)

    def get_model_path(self, voice_key: str) -> Optional[Path]:
        """Get path to a downloaded model.

        Args:
            voice_key: Key from LANGUAGE_CONFIGS.

        Returns:
            The voice's directory if its model file exists, otherwise None
            (also None for an unknown ``voice_key``).
        """
        config = LANGUAGE_CONFIGS.get(voice_key)
        if config is None:
            return None

        model_dir = self.models_dir / voice_key
        if (model_dir / config.model_filename).exists():
            return model_dir
        return None

    def list_downloaded_models(self) -> List[str]:
        """List all downloaded models.

        Returns:
            Voice keys, in LANGUAGE_CONFIGS order, whose model file exists
            under ``models_dir``.
        """
        return [
            voice_key
            for voice_key, config in LANGUAGE_CONFIGS.items()
            if (self.models_dir / voice_key / config.model_filename).exists()
        ]

    def get_model_size(self, voice_key: str) -> Optional[int]:
        """Get size of downloaded model in bytes.

        Only regular files directly inside the voice's directory are counted;
        sub-directories are not descended into.

        Args:
            voice_key: Key from LANGUAGE_CONFIGS.

        Returns:
            Total size in bytes, or None if the voice is not downloaded.
        """
        model_dir = self.get_model_path(voice_key)
        if model_dir is None:
            return None

        return sum(f.stat().st_size for f in model_dir.iterdir() if f.is_file())
|
|
|
|
|
|
|
|
def download_models_cli():
    """CLI entry point for downloading models.

    Options:
        --voice KEY   Download a single voice (e.g., hi_male).
        --lang CODE   Download all voices for a language (e.g., hi).
        --all         Download every configured voice.
        --list        Print the configured voices, marking downloaded ones.
        --force       Re-download even if the files already exist.

    ``--list`` takes precedence; otherwise the first of ``--voice``,
    ``--lang`` and ``--all`` that was given is used. With none of them the
    help text is printed.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Download SYSPIN TTS models")
    parser.add_argument(
        "--voice", type=str, help="Specific voice to download (e.g., hi_male)"
    )
    parser.add_argument(
        "--lang", type=str, help="Download all voices for a language (e.g., hi)"
    )
    parser.add_argument("--all", action="store_true", help="Download all models")
    parser.add_argument("--list", action="store_true", help="List available models")
    parser.add_argument("--force", action="store_true", help="Force re-download")

    args = parser.parse_args()

    # Download progress is reported through ``logger.info``. Without a
    # configured handler INFO records are dropped, so a successful run would
    # print nothing. ``basicConfig`` does nothing if the root logger already
    # has handlers, so an embedding application's setup is left untouched.
    logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    downloader = ModelDownloader()

    if args.list:
        print("Available voices:")
        for key, config in LANGUAGE_CONFIGS.items():
            # Tick mark for voices whose model file is already on disk.
            downloaded = "✓" if downloader.get_model_path(key) else " "
            print(f"  [{downloaded}] {key}: {config.name} ({config.code})")
        return

    if args.voice:
        downloader.download_model(args.voice, force=args.force)
    elif args.lang:
        downloader.download_language(args.lang, force=args.force)
    elif args.all:
        downloader.download_all_models(force=args.force)
    else:
        parser.print_help()
|
|
|
|
|
|
|
|
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    download_models_cli()
|
|
|