Spaces:
Runtime error
Runtime error
File size: 3,717 Bytes
1378843 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import logging
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import Any, Generator
import boto3
from pydantic import BaseModel, Field, TypeAdapter
from tqdm import tqdm
from .utils import data_dir, env_str
@dataclass
class S3VoiceObj:
    """Key and size of a single object taken from an S3 listing."""

    # Full S3 object key, including any prefix.
    key: str
    # Object size as reported by the listing entry's "Size" value.
    size: int

    @property
    def name(self) -> str:
        """Return the part of the key after the last "/".

        The whole key is returned when it contains no "/"; an empty string
        is returned when the key ends with "/".
        """
        return self.key.rsplit("/", 1)[-1]

    @classmethod
    def from_s3_obj(cls, obj: Any) -> "S3VoiceObj":
        """Build an instance from a listing entry with "Key" and "Size" items.

        Uses ``cls`` rather than the hard-coded class name so that subclasses
        get instances of their own type.

        Raises:
            KeyError: if ``obj`` lacks "Key" or "Size".
        """
        return cls(key=obj["Key"], size=obj["Size"])
class Voice(BaseModel):
    """Definition of one voice, as stored in a ``*.json`` file.

    Within this module only ``name``, ``model`` and ``tts`` are read. The
    remaining fields are carried along unchanged; presumably they are
    conversion settings consumed by the inference code -- confirm against
    the callers before relying on their meaning.
    """

    # Key under which the voice is registered.
    name: str
    # Model file reference; the loader resolves it against the voices
    # directory and replaces it with the resolved path.
    model: str
    # Must match the name of a known TTS voice for the voice to be accepted.
    tts: str

    # --- Settings not interpreted in this module ---
    index: str = ""
    autotune: float | None = None
    clean: float | None = 0.5
    upscale: bool = False
    pitch: int = 0
    filter_radius: int = 3
    index_rate: float = 0.75
    rms_mix_rate: float = 1
    protect: float = 0.5
    hop_length: int = 128
    f0_method: str = "rmvpe"
    embedder_model: str = "contentvec"
class TTSVoice(BaseModel):
    """One text-to-speech voice entry, identified by its short name."""

    # Filled from the "ShortName" key of the input data via the field alias.
    name: str = Field(alias="ShortName")
class VoiceManager:
    """Mirror voice files from S3 into a local directory and load voice definitions.

    Objects under the configured bucket/prefix are stored flat in
    ``voices_dir`` using only the last path component of their key, so two
    keys with the same final component map to the same local file.
    """

    def __init__(self) -> None:
        """Create the S3 client and read the bucket, key prefix and local directory."""
        self.s3 = boto3.client("s3")
        self.bucket = env_str("BUCKET")
        self.prefix = env_str("VOICES_KEY_PREFIX")
        self.voices_dir = Path(data_dir("voices"))

    def _iter_s3_objects(self) -> Generator[S3VoiceObj, None, None]:
        """Yield every object stored under the configured bucket and prefix.

        A paginator is used because a single ``list_objects_v2`` call returns
        at most one page of keys; without it larger listings would be
        silently truncated.
        """
        paginator = self.s3.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=self.bucket, Prefix=self.prefix):
            for entry in page.get("Contents", []):
                # A key ending in "/" has an empty file name, which would map
                # to the voices directory itself rather than to a file in it.
                if entry["Key"].endswith("/"):
                    continue
                yield S3VoiceObj.from_s3_obj(entry)

    def _needs_download(self, obj: S3VoiceObj) -> bool:
        """Return True when the local copy of ``obj`` is missing or differs in size."""
        local_path = self.voices_dir / obj.name
        return not local_path.exists() or local_path.stat().st_size != obj.size

    def get_voices_size_if_missing(self) -> int:
        """Return the total size of the remote files that still need downloading.

        A file counts when it is absent locally or its local size differs
        from the size reported by S3.

        Side effect: every non-directory entry in the voices directory that
        does not correspond to a listed S3 object is deleted. If the listing
        is empty, that means all local files are removed.
        """
        total_size = 0
        expected: set[Path] = set()
        for obj in self._iter_s3_objects():
            expected.add(self.voices_dir / obj.name)
            if self._needs_download(obj):
                total_size += obj.size

        for path in self.voices_dir.glob("*"):
            # unlink() cannot remove directories, so leave them alone instead
            # of failing the whole call.
            if path not in expected and not path.is_dir():
                path.unlink()
        return total_size

    def download_voice_files(self, progress_bar: tqdm) -> None:
        """Download every missing or size-mismatched voice file from S3.

        Args:
            progress_bar: advanced by the number of bytes transferred as each
                download progresses.
        """

        def callback(bytes_amount: int) -> None:
            progress_bar.update(bytes_amount)

        for obj in self._iter_s3_objects():
            if not self._needs_download(obj):
                continue
            self.s3.download_file(
                Bucket=self.bucket,
                Key=obj.key,
                # Pass a plain string: not every boto3 release accepts
                # path-like objects for Filename.
                Filename=str(self.voices_dir / obj.name),
                Callback=callback,
            )

    @cached_property
    def tts_voices(self) -> dict[str, TTSVoice]:
        """Return the known TTS voices keyed by name.

        Loaded once from ``rvc/lib/tools/tts_voices.json``, a path relative
        to the current working directory, and cached on the instance.
        """
        path = Path("rvc/lib/tools/tts_voices.json")
        voices = TypeAdapter(list[TTSVoice]).validate_json(path.read_bytes())
        return {v.name: v for v in voices}

    @property
    def voice_names(self) -> list[str]:
        """Return the names of all usable voices."""
        return list(self.voices)

    @cached_property
    def voices(self) -> dict[str, Voice]:
        """Return the usable voices keyed by name.

        Every ``*.json`` file in the voices directory is parsed as a
        ``Voice``. A voice is skipped, with a warning, when its model file is
        not present in the voices directory or when its ``tts`` value is not
        a known TTS voice. For accepted voices ``model`` is replaced by the
        path of the model file inside the voices directory.

        The result is computed once and cached on the instance, so files
        downloaded afterwards are not picked up.
        """
        rv: dict[str, Voice] = {}
        # Sorted so that ordering and duplicate-name resolution do not depend
        # on the file system's enumeration order.
        for path in sorted(self.voices_dir.glob("*.json")):
            voice = Voice.model_validate_json(path.read_bytes())
            model_path = self.voices_dir / voice.model
            if not model_path.exists():
                logging.warning("Voice %s missing model %s", voice.name, voice.model)
                continue
            if voice.tts not in self.tts_voices:
                # Report the offending tts value, not the model name.
                logging.warning("Voice %s references invalid tts %s", voice.name, voice.tts)
                continue
            voice.model = str(model_path)
            rv[voice.name] = voice
        return rv
# Module-level shared instance. Constructing it here means that importing this
# module runs VoiceManager.__init__, which creates the boto3 S3 client and
# reads the BUCKET and VOICES_KEY_PREFIX settings through env_str.
voice_manager = VoiceManager()
|