zhengr's picture
init
c02bdcd
import os
from pathlib import Path
import hashlib
import requests
from io import BytesIO
from typing import Dict, Tuple, Optional
from mmap import mmap, ACCESS_READ
from .log import logger
def sha256(fileno: int) -> str:
    """Return the hex SHA-256 digest of the whole file behind *fileno*.

    The file is memory-mapped read-only and the mapping is handed straight to
    ``hashlib``, so the contents are not first copied into a ``bytes`` object.
    The mapping always starts at offset 0, so the descriptor's current file
    position does not affect the result.

    Args:
        fileno: OS-level file descriptor opened for reading, e.g.
            ``f.fileno()`` of a file opened in ``"rb"`` mode.

    Returns:
        The 64-character lowercase hexadecimal digest.
    """
    # mmap() refuses to map a zero-length file (ValueError). Hash the empty
    # input directly so an empty/truncated file is reported as a mismatch
    # instead of crashing the caller.
    if os.fstat(fileno).st_size == 0:
        return hashlib.sha256(b"").hexdigest()
    with mmap(fileno, 0, access=ACCESS_READ) as data:
        return hashlib.sha256(data).hexdigest()
def check_model(
    dir_name: Path, model_name: str, hash: str, remove_incorrect=False
) -> bool:
    """Check that ``dir_name / model_name`` exists and has SHA-256 digest *hash*.

    Args:
        dir_name: Directory containing the file.
        model_name: File name inside *dir_name*.
        hash: Expected lowercase hex SHA-256 digest.
        remove_incorrect: When True, manage the ``<file>.bak`` sibling.
            On a mismatch, the file is renamed to ``<file>.bak``, or deleted
            if a ``.bak`` already exists. On a match, any existing ``.bak``
            is deleted.

    Returns:
        True if the file exists and its digest equals *hash*, otherwise False.
    """
    model_path = dir_name / model_name
    logger.get_logger().debug(f"checking {model_path.as_posix()}...")
    if not os.path.exists(model_path):
        logger.get_logger().info(f"{model_path} not exist.")
        return False

    with open(model_path, "rb") as fp:
        actual = sha256(fp.fileno())

    backup = f"{model_path}.bak"

    if actual == hash:
        # The file is good, so a .bak kept from an earlier mismatch is no longer needed.
        if remove_incorrect and os.path.exists(backup):
            os.remove(backup)
        return True

    logger.get_logger().warning(f"{model_path} sha256 hash mismatch.")
    logger.get_logger().info(f"expected: {hash}")
    logger.get_logger().info(f"real val: {actual}")
    if remove_incorrect:
        # Keep the first bad copy as .bak; later bad copies are simply deleted.
        if os.path.exists(backup):
            os.remove(str(model_path))
        else:
            os.rename(str(model_path), backup)
    return False
def check_folder(
    base_dir: Path,
    *inner_dirs: str,
    names: Tuple[str, ...],
    sha256_map: Dict[str, str],
    update=False,
) -> bool:
    """Verify every file in *names* under ``base_dir / inner_dirs[0] / ...``.

    The expected digest of each file is looked up in *sha256_map* under the key
    ``"sha256_" + "".join(f"{d}_" for d in inner_dirs) + name`` with every
    ``.`` in ``name`` replaced by ``_``. For example, inner dirs
    ``("asset", "gpt")`` and name ``"config.json"`` give the key
    ``"sha256_asset_gpt_config_json"``.

    Args:
        base_dir: Root directory that the sub-directories are resolved against.
        *inner_dirs: Sub-directory components, joined in order.
        names: File names to check inside the resulting directory.
        sha256_map: Mapping from the keys described above to hex digests.
        update: Forwarded to ``check_model`` as ``remove_incorrect``.

    Returns:
        True if every file passes; False as soon as one file fails.

    Raises:
        KeyError: If *sha256_map* has no entry for one of the files.
    """
    current_dir = base_dir
    key_prefix = "sha256_"
    for d in inner_dirs:
        current_dir /= d
        key_prefix += f"{d}_"
    for model in names:
        key = key_prefix + model.replace(".", "_")
        if not check_model(current_dir, model, sha256_map[key], update):
            return False
    return True
def check_all_assets(base_dir: Path, sha256_map: Dict[str, str], update=False) -> bool:
    """Verify all bundled asset files under *base_dir* against *sha256_map*.

    The folders are checked in a fixed order: ``asset``, ``asset/gpt``, then
    ``asset/tokenizer``. Checking stops at the first folder with a missing or
    mismatching file.

    Args:
        base_dir: Root directory that contains the ``asset`` folder.
        sha256_map: Expected digests, keyed as described in ``check_folder``.
        update: Forwarded to ``check_folder``.

    Returns:
        True if every asset is present and matches, otherwise False.
    """
    logger.get_logger().info("checking assets...")

    # (sub-directory components, file names) for each folder, in check order.
    asset_groups = (
        (
            ("asset",),
            ("Decoder.pt", "DVAE_full.pt", "Embed.safetensors", "Vocos.pt"),
        ),
        (
            ("asset", "gpt"),
            ("config.json", "model.safetensors"),
        ),
        (
            ("asset", "tokenizer"),
            (
                "special_tokens_map.json",
                "tokenizer_config.json",
                "tokenizer.json",
            ),
        ),
    )
    for subdirs, filenames in asset_groups:
        folder_ok = check_folder(
            base_dir,
            *subdirs,
            names=filenames,
            sha256_map=sha256_map,
            update=update,
        )
        if not folder_ok:
            return False

    logger.get_logger().info("all assets are already latest.")
    return True
def download_and_extract_tar_gz(
    url: str, folder: str, headers: Optional[Dict[str, str]] = None
):
    """Download a gzip-compressed tar archive from *url* and extract it into *folder*.

    The archive is buffered in memory, never written to disk.

    Args:
        url: Archive URL.
        folder: Destination directory for the extracted members.
        headers: Optional extra HTTP request headers.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection/read failures or timeouts.
        tarfile.TarError: If the payload is not a valid ``.tar.gz`` archive.
    """
    import tarfile

    logger.get_logger().info(f"downloading {url}")
    with BytesIO() as out_file:
        with requests.get(
            url, headers=headers, stream=True, timeout=(10, 3)
        ) as response:
            # Fail fast on HTTP errors instead of handing an error page to tarfile.
            response.raise_for_status()
            # Copy the streamed body chunk by chunk. Reading response.content
            # first would hold the whole archive twice.
            for chunk in response.iter_content(chunk_size=1 << 16):
                out_file.write(chunk)
        out_file.seek(0)
        logger.get_logger().info("downloaded.")
        with tarfile.open(fileobj=out_file, mode="r:gz") as tar:
            if hasattr(tarfile, "data_filter"):
                # Python 3.12+ and security backports: reject absolute paths,
                # "..", links escaping *folder*, device files, etc.
                tar.extractall(folder, filter="data")
            else:
                tar.extractall(folder)
    logger.get_logger().info(f"extracted into {folder}")
def download_and_extract_zip(
    url: str, folder: str, headers: Optional[Dict[str, str]] = None
):
    """Download a zip archive from *url* and extract it into *folder*.

    The archive is buffered in memory, never written to disk.

    Args:
        url: Archive URL.
        folder: Destination directory for the extracted members.
        headers: Optional extra HTTP request headers.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection/read failures or timeouts.
        zipfile.BadZipFile: If the payload is not a valid zip archive.
    """
    import zipfile

    logger.get_logger().info(f"downloading {url}")
    with BytesIO() as out_file:
        with requests.get(
            url, headers=headers, stream=True, timeout=(10, 3)
        ) as response:
            # Fail fast on HTTP errors instead of handing an error page to zipfile.
            response.raise_for_status()
            # Copy the streamed body chunk by chunk. Reading response.content
            # first would hold the whole archive twice.
            for chunk in response.iter_content(chunk_size=1 << 16):
                out_file.write(chunk)
        out_file.seek(0)
        logger.get_logger().info("downloaded.")
        # ZipFile.extractall strips absolute paths and ".." components itself.
        with zipfile.ZipFile(out_file) as zip_ref:
            zip_ref.extractall(folder)
    logger.get_logger().info(f"extracted into {folder}")
def download_dns_yaml(
    url: str, folder: str, headers: Optional[Dict[str, str]] = None
):
    """Download the resource at *url* and save it as ``dns.yaml`` in *folder*.

    Any existing ``dns.yaml`` in *folder* is overwritten.

    Args:
        url: URL of the file to fetch.
        folder: Existing directory to write ``dns.yaml`` into.
        headers: Optional extra HTTP request headers.

    Raises:
        requests.HTTPError: If the server answers with an error status. In
            that case no file is written.
        requests.RequestException: On connection/read failures or timeouts.
    """
    logger.get_logger().info(f"downloading {url}")
    # NOTE(review): the connect timeout here is 100s versus 10s in the archive
    # downloaders. Confirm this is intended.
    with requests.get(
        url, headers=headers, stream=True, timeout=(100, 3)
    ) as response:
        # Check the status before opening the file, so an HTML error page is
        # never saved as dns.yaml.
        response.raise_for_status()
        with open(os.path.join(folder, "dns.yaml"), "wb") as out_file:
            for chunk in response.iter_content(chunk_size=1 << 16):
                out_file.write(chunk)
    logger.get_logger().info(f"downloaded into {folder}")
def _fetch_rvcmd(
    base_url: str,
    tmpdir: str,
    version: str,
    system_type: str,
    architecture: str,
    is_win: bool,
) -> str:
    """Download and unpack the rvcmd release for this platform into *tmpdir*; return its path."""
    suffix = "zip" if is_win else "tar.gz"
    rvcmd_url = base_url + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}"
    cmdfile = os.path.join(tmpdir, "rvcmd")
    if is_win:
        download_and_extract_zip(rvcmd_url, tmpdir)
        cmdfile += ".exe"
    else:
        download_and_extract_tar_gz(rvcmd_url, tmpdir)
        os.chmod(cmdfile, 0o755)
    return cmdfile


def download_all_assets(tmpdir: str, version="0.2.8"):
    """Fetch the rvcmd downloader into *tmpdir* and run it for ``assets/chtts``.

    The rvcmd binary is fetched from the GitHub releases of
    fumiama/RVC-Models-Downloader. If anything on that path fails, including
    rvcmd exiting with a non-zero status, the process is retried via the
    gitea.seku.su mirror. The retry also downloads a ``dns.yaml`` and passes
    it to rvcmd with ``-dns``.

    Args:
        tmpdir: Scratch directory for the rvcmd binary and ``dns.yaml``.
        version: rvcmd release version, without the leading ``v``.

    Exits the process with status 1 if the machine architecture is not in the
    supported table.
    """
    import platform
    import subprocess
    import sys

    archs = {
        "aarch64": "arm64",
        "armv8l": "arm64",
        "arm64": "arm64",
        "x86": "386",
        "i386": "386",
        "i686": "386",
        "386": "386",
        "x86_64": "amd64",
        "x64": "amd64",
        "amd64": "amd64",
    }
    system_type = platform.system().lower()
    machine = platform.machine().lower()
    is_win = system_type == "windows"
    architecture = archs.get(machine)
    if not architecture:
        # Report the raw machine name; the mapped value is None at this point.
        logger.get_logger().error(f"architecture {machine} is not supported")
        sys.exit(1)

    try:
        cmdfile = _fetch_rvcmd(
            "https://github.com/fumiama/RVC-Models-Downloader/releases/download/",
            tmpdir,
            version,
            system_type,
            architecture,
            is_win,
        )
        # check=True: a failing rvcmd run (not just a failed binary download)
        # should also fall through to the mirror below.
        subprocess.run([cmdfile, "-notui", "-w", "0", "assets/chtts"], check=True)
    except Exception as e:
        # Deliberately broad: any failure on the primary path triggers the
        # mirror. Log the cause instead of hiding it.
        logger.get_logger().warning(
            f"downloading from github failed ({e}), retrying with mirror..."
        )
        download_dns_yaml(
            "https://gitea.seku.su/fumiama/RVC-Models-Downloader/raw/branch/main/dns.yaml",
            tmpdir,
            headers={
                "user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36 Edg/128.0.0.0"
            },
        )
        cmdfile = _fetch_rvcmd(
            "https://gitea.seku.su/fumiama/RVC-Models-Downloader/releases/download/",
            tmpdir,
            version,
            system_type,
            architecture,
            is_win,
        )
        result = subprocess.run(
            [
                cmdfile,
                "-notui",
                "-w",
                "0",
                "-dns",
                os.path.join(tmpdir, "dns.yaml"),
                "assets/chtts",
            ]
        )
        if result.returncode != 0:
            logger.get_logger().error(f"rvcmd exited with code {result.returncode}")