import ctypes
import io
import os
import sys
from pathlib import Path
from typing import IO, Union

import ffmpeg
import gradio as gr
import numpy as np
import pandas as pd
from torch.serialization import _opener

from tools.i18n.i18n import I18nAuto

i18n = I18nAuto(language=os.environ.get("language", "Auto"))


def load_audio(file, sr):
    """Decode an audio file to a mono float32 waveform resampled to `sr`."""
    try:
        # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
        # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
        # Requires the ffmpeg CLI and the `ffmpeg-python` package to be installed.
        file = clean_path(file)  # Guard against pasted paths with stray spaces, quotes, or newlines
        if not os.path.exists(file):
            raise RuntimeError("You input a wrong audio path that does not exist, please fix it!")
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
        )
    except Exception:
        # Re-run without capturing stderr so ffmpeg's error output is exposed on the console,
        # then surface a localized failure to the caller.
        out, _ = (
            ffmpeg.input(file, threads=0)
            .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
            .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True)
        )
        raise RuntimeError(i18n("音频加载失败"))  # "Failed to load audio"

    return np.frombuffer(out, np.float32).flatten()
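
# Usage sketch (illustrative, not part of the original module): decode a local file to
# 16 kHz mono float32 samples. "sample.wav" is a hypothetical path; the ffmpeg CLI must
# be on PATH.
#
#     audio = load_audio("sample.wav", 16000)
#     print(audio.shape, audio.dtype)  # e.g. (160000,), float32 for a 10-second clip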


def clean_path(path_str: str) -> str:
    """Normalize a user-supplied path: drop trailing separators, unify separators for the
    current OS, and strip stray spaces, quotes, newlines, and U+202A marks."""
    if path_str.endswith(("\\", "/")):
        return clean_path(path_str[0:-1])
    path_str = path_str.replace("/", os.sep).replace("\\", os.sep)
    return path_str.strip(" '\n\"\u202a")
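
# Usage sketch (illustrative): clean_path undoes the damage of copy-pasted paths.
#
#     clean_path("C:/data/audio/")   # -> "C:\\data\\audio" on Windows
#     clean_path(" my_list.list\n")  # -> "my_list.list"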


def check_for_existance(file_list: list = None, is_train=False, is_dataset_processing=False):
    """Check that every path in file_list exists, emitting Gradio warnings for missing ones."""
    files_status = []
    if is_train and file_list:
        # Training expects these artifacts inside the experiment directory (file_list[0]).
        file_list.append(os.path.join(file_list[0], "2-name2text.txt"))
        file_list.append(os.path.join(file_list[0], "3-bert"))
        file_list.append(os.path.join(file_list[0], "4-cnhubert"))
        file_list.append(os.path.join(file_list[0], "5-wav32k"))
        file_list.append(os.path.join(file_list[0], "6-name2semantic.tsv"))
    for file in file_list:
        files_status.append(os.path.exists(file))
    if sum(files_status) != len(files_status):
        if is_train:
            for file, status in zip(file_list, files_status):
                if not status:
                    gr.Warning(file)
            gr.Warning(i18n("以下文件或文件夹不存在"))  # "The following files or folders do not exist"
            return False
        elif is_dataset_processing:
            # file_list is [list_path, audio_path]; the audio folder is optional.
            missing = False
            if not files_status[0]:
                gr.Warning(file_list[0])
                missing = True
            if not files_status[1] and file_list[1]:
                gr.Warning(file_list[1])
                missing = True
            if missing:
                gr.Warning(i18n("以下文件或文件夹不存在"))
                return False
        else:
            if file_list[0]:
                gr.Warning(file_list[0])
                gr.Warning(i18n("以下文件或文件夹不存在"))
            else:
                gr.Warning(i18n("路径不能为空"))  # "The path must not be empty"
            return False
    return True
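
# Usage sketch (illustrative; paths are hypothetical): validate UI inputs before a
# dataset-processing step.
#
#     if check_for_existance(["data/demo.list", "data/audio"], is_dataset_processing=True):
#         ...  # proceed with preprocessing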


def check_details(path_list=None, is_train=False, is_dataset_processing=False):
    """Validate the contents referenced by path_list, beyond mere existence."""
    if is_dataset_processing:
        list_path, audio_path = path_list
        if not list_path.endswith(".list"):
            gr.Warning(i18n("请填入正确的List路径"))  # "Please enter a valid .list path"
            return
        if audio_path:
            if not os.path.isdir(audio_path):
                gr.Warning(i18n("请填入正确的音频文件夹路径"))  # "Please enter a valid audio folder path"
                return
        with open(list_path, "r", encoding="utf8") as f:
            first_line = f.readline().strip("\n")
        # Each annotation line has four "|"-separated fields; the first is the audio path.
        wav_name, _, __, ___ = first_line.split("|")
        wav_name = clean_path(wav_name)
        if audio_path:
            wav_name = os.path.basename(wav_name)
            wav_path = "%s/%s" % (audio_path, wav_name)
        else:
            wav_path = wav_name
        if not os.path.exists(wav_path):
            gr.Warning(wav_path + i18n("路径错误"))  # "... path is wrong"
            return
    if is_train:
        path_list.append(os.path.join(path_list[0], "2-name2text.txt"))
        path_list.append(os.path.join(path_list[0], "4-cnhubert"))
        path_list.append(os.path.join(path_list[0], "5-wav32k"))
        path_list.append(os.path.join(path_list[0], "6-name2semantic.tsv"))
        phone_path, hubert_path, wav_path, semantic_path = path_list[1:]
        with open(phone_path, "r", encoding="utf-8") as f:
            if not f.read(1):
                gr.Warning(i18n("缺少音素数据集"))  # "Missing phoneme dataset"
        if not os.listdir(hubert_path):
            gr.Warning(i18n("缺少Hubert数据集"))  # "Missing HuBERT dataset"
        if not os.listdir(wav_path):
            gr.Warning(i18n("缺少音频数据集"))  # "Missing audio dataset"
        df = pd.read_csv(semantic_path, delimiter="\t", encoding="utf-8")
        if len(df) < 1:
            gr.Warning(i18n("缺少语义数据集"))  # "Missing semantic dataset"
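
# Illustrative annotation line for the .list file checked above. The field names are
# descriptive assumptions; only the four-field, "|"-separated shape with a leading
# audio path is implied by the code:
#
#     dataset/audio/vocal_001.wav|speaker_a|zh|今天天气不错。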


def load_cudnn():
    """Preload the cuDNN shared libraries bundled with the installed torch / NVIDIA wheels."""
    import torch

    if not torch.cuda.is_available():
        print("[INFO] CUDA is not available, skipping cuDNN setup.")
        return

    if sys.platform == "win32":
        torch_lib_dir = Path(torch.__file__).parent / "lib"
        if torch_lib_dir.exists():
            os.add_dll_directory(str(torch_lib_dir))
            print(f"[INFO] Added DLL directory: {torch_lib_dir}")
            matching_files = sorted(torch_lib_dir.glob("cudnn_cnn*.dll"))
            if not matching_files:
                print(f"[ERROR] No cudnn_cnn*.dll found in {torch_lib_dir}")
                return
            for dll_path in matching_files:
                dll_name = os.path.basename(dll_path)
                try:
                    ctypes.CDLL(dll_name)
                    print(f"[INFO] Loaded: {dll_name}")
                except OSError as e:
                    print(f"[WARNING] Failed to load {dll_name}: {e}")
        else:
            print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")

    elif sys.platform == "linux":
        site_packages = Path(torch.__file__).resolve().parents[1]
        cudnn_dir = site_packages / "nvidia" / "cudnn" / "lib"
        if not cudnn_dir.exists():
            print(f"[ERROR] cudnn dir not found: {cudnn_dir}")
            return
        matching_files = sorted(cudnn_dir.glob("libcudnn_cnn*.so*"))
        if not matching_files:
            print(f"[ERROR] No libcudnn_cnn*.so* found in {cudnn_dir}")
            return
        for so_path in matching_files:
            try:
                ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)  # type: ignore
                print(f"[INFO] Loaded: {so_path}")
            except OSError as e:
                print(f"[WARNING] Failed to load {so_path}: {e}")


def load_nvrtc():
    """Preload the NVRTC shared libraries bundled with the installed torch / NVIDIA wheels."""
    import torch

    if not torch.cuda.is_available():
        print("[INFO] CUDA is not available, skipping nvrtc setup.")
        return

    if sys.platform == "win32":
        torch_lib_dir = Path(torch.__file__).parent / "lib"
        if torch_lib_dir.exists():
            os.add_dll_directory(str(torch_lib_dir))
            print(f"[INFO] Added DLL directory: {torch_lib_dir}")
            matching_files = sorted(torch_lib_dir.glob("nvrtc*.dll"))
            if not matching_files:
                print(f"[ERROR] No nvrtc*.dll found in {torch_lib_dir}")
                return
            for dll_path in matching_files:
                dll_name = os.path.basename(dll_path)
                try:
                    ctypes.CDLL(dll_name)
                    print(f"[INFO] Loaded: {dll_name}")
                except OSError as e:
                    print(f"[WARNING] Failed to load {dll_name}: {e}")
        else:
            print(f"[WARNING] Torch lib directory not found: {torch_lib_dir}")

    elif sys.platform == "linux":
        site_packages = Path(torch.__file__).resolve().parents[1]
        nvrtc_dir = site_packages / "nvidia" / "cuda_nvrtc" / "lib"
        if not nvrtc_dir.exists():
            print(f"[ERROR] nvrtc dir not found: {nvrtc_dir}")
            return
        matching_files = sorted(nvrtc_dir.glob("libnvrtc*.so*"))
        if not matching_files:
            print(f"[ERROR] No libnvrtc*.so* found in {nvrtc_dir}")
            return
        for so_path in matching_files:
            try:
                ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)  # type: ignore
                print(f"[INFO] Loaded: {so_path}")
            except OSError as e:
                print(f"[WARNING] Failed to load {so_path}: {e}")
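
# Usage sketch (illustrative): call these once during startup, before code that needs
# cuDNN or NVRTC to be resolvable by the dynamic loader.
#
#     load_cudnn()
#     load_nvrtc()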


class DictToAttrRecursive(dict):
    """A dict whose items are also reachable as attributes, applied recursively to nested dicts."""

    def __init__(self, input_dict):
        super().__init__(input_dict)
        for key, value in input_dict.items():
            if isinstance(value, dict):
                value = DictToAttrRecursive(value)
            self[key] = value
            setattr(self, key, value)

    def __getattr__(self, item):
        try:
            return self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")

    def __setattr__(self, key, value):
        if isinstance(value, dict):
            value = DictToAttrRecursive(value)
        super(DictToAttrRecursive, self).__setitem__(key, value)
        super().__setattr__(key, value)

    def __delattr__(self, item):
        try:
            del self[item]
        except KeyError:
            raise AttributeError(f"Attribute {item} not found")
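
# Usage sketch (illustrative): wrap a nested config dict so values can be read either
# as keys or as attributes.
#
#     hps = DictToAttrRecursive({"model": {"hidden_dim": 512}})
#     hps.model.hidden_dim        # 512
#     hps["model"]["hidden_dim"]  # 512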


class _HeadOverlay(io.IOBase, IO):
    """Read-only wrapper around a binary stream that overlays `patch` at `offset`, so readers
    see the patched bytes (by default b"PK" at the start) without the underlying file being
    modified."""

    def __init__(self, base: IO[bytes], patch: bytes = b"PK", offset: int = 0):
        super(io.IOBase, self).__init__()
        if not base.readable():
            raise ValueError("Base stream must be readable")
        self._base = base
        self._patch = patch
        self._off = offset

    def readable(self) -> bool:
        return True

    def writable(self) -> bool:
        return False

    def seekable(self) -> bool:
        try:
            return self._base.seekable()
        except Exception:
            return False

    def tell(self) -> int:
        return self._base.tell()

    def seek(self, pos: int, whence: int = os.SEEK_SET) -> int:
        return self._base.seek(pos, whence)

    def read(self, size: int = -1) -> bytes:
        start = self._base.tell()
        data = self._base.read(size)
        if not data:
            return data
        # Splice the patch into the returned chunk wherever it overlaps [start, end).
        end = start + len(data)
        ps, pe = self._off, self._off + len(self._patch)
        a, b = max(start, ps), min(end, pe)
        if a < b:
            buf = bytearray(data)
            s_rel = a - start
            e_rel = b - start
            p_rel = a - ps
            buf[s_rel:e_rel] = self._patch[p_rel : p_rel + (e_rel - s_rel)]
            return bytes(buf)
        return data

    def readinto(self, b) -> int:
        start: int = self._base.tell()
        nread = self._base.readinto(b)  # type: ignore
        # Patch the caller's buffer in place over the overlapping range, if any.
        end = start + nread
        ps, pe = self._off, self._off + len(self._patch)
        a, c = max(start, ps), min(end, pe)
        if a < c:
            mv = memoryview(b)
            s_rel = a - start
            e_rel = c - start
            p_rel = a - ps
            mv[s_rel:e_rel] = self._patch[p_rel : p_rel + (e_rel - s_rel)]
        return nread

    def close(self) -> None:
        try:
            self._base.close()
        finally:
            super().close()

    def flush(self) -> None:
        try:
            self._base.flush()
        except Exception:
            pass

    def write(self, b) -> int:
        raise io.UnsupportedOperation("not writable")

    def raw(self):
        return self._base

    def __getattr__(self, name):
        # Report missing attributes as None so duck-typed probes do not raise.
        return None
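
# Usage sketch (illustrative; "checkpoint.bin" is a hypothetical file): the first two
# bytes read back as b"PK" (the ZIP magic) while the file on disk stays untouched.
#
#     with open("checkpoint.bin", "rb") as fh:
#         overlaid = _HeadOverlay(fh, b"PK", 0)
#         overlaid.read(4)  # b"PK" followed by the file's original 3rd and 4th bytes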


class _open_file(_opener[IO[bytes]]):
    """File opener that, in read mode, wraps the file in a _HeadOverlay so the stream starts
    with the ZIP magic b"PK"; built on the _opener base imported from torch.serialization."""

    def __init__(self, name: Union[str, os.PathLike[str]], mode: str) -> None:
        f = open(name, mode)
        if "r" in mode:
            f = _HeadOverlay(f, b"PK", 0)
        super().__init__(f)

    def __exit__(self, *args):
        self.file_like.close()
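
# Usage sketch (an assumption about usage, not confirmed by this file; "model.ckpt" is a
# hypothetical path): opening through this class yields a stream whose first two bytes
# read as the ZIP magic.
#
#     with _open_file("model.ckpt", "rb") as fh:
#         head = fh.read(2)  # b"PK"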