Spaces:
Sleeping
Sleeping
File size: 5,167 Bytes
c42fe7e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
# inference/pipeline.py
import os
import json
import sys
from pathlib import Path
from typing import Optional
from utils.hparams import set_hparams, hparams
from inference.ds_variance import DiffSingerVarianceInfer
from inference.ds_acoustic import DiffSingerAcousticInfer
from utils.infer_utils import parse_commandline_spk_mix, trans_key
from webapp.services.parsing.ds_validator import validate_ds
# Directory one level above the directory that contains this file
# (resolved to an absolute path). run_inference builds its first
# config path as PROJECT_ROOT / "checkpoints" / <exp> / "config.yaml".
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Base directory under which run_inference looks up "<exp>/config.yaml"
# for the variance and acoustic experiments.
# NOTE(review): presumably the checkpoints are downloaded here at startup
# by code outside this file -- confirm.
HF_CHECKPOINTS_DIR = "/tmp/cantussvs_v1/checkpoints"
def _apply_hparams(config_path, exp_name: str) -> None:
    """Point the global hparams at ``exp_name`` using ``config_path``.

    ``set_hparams`` is invoked without explicit arguments, so the settings
    are handed over through ``sys.argv``. The caller is responsible for
    restoring ``sys.argv`` afterwards.
    """
    sys.argv = [
        "",
        "--config", str(config_path),
        "--exp_name", exp_name,
        "--infer",
    ]
    set_hparams(print_hparams=False)


def _load_ds_segments(ds_file: Path, error_prefix: str) -> list:
    """Load a DS JSON file, normalise it to a list and validate every entry.

    Raises:
        ValueError: if ``validate_ds`` rejects an entry; the message is
            ``"<error_prefix>: <original error>"`` and the original
            exception is attached as the cause.
    """
    with open(ds_file, "r", encoding="utf-8") as f:
        segments = json.load(f)
    # A DS file may hold a single object instead of a list of objects.
    if not isinstance(segments, list):
        segments = [segments]
    for segment in segments:
        try:
            validate_ds(segment)
        except Exception as e:
            raise ValueError(f"{error_prefix}: {e}") from e
    return segments


def run_inference(
    ds_path: Path,
    output_dir: Path,
    title: str,
    *,
    variance_exp: str = "regular_variance_v1",
    acoustic_exp: str = "debug_test",
    seed: int = 42,
    num_runs: int = 1,
    key_shift: int = 0,
    gender: Optional[float] = None
) -> Path:
    """
    Runs the full pipeline: variance model => acoustic model;
    returns the path to the generated WAV.

    Args:
        ds_path: Input DS (JSON) file; a single object or a list of objects.
        output_dir: Directory passed to both inference stages as ``out_dir``.
        title: Base name of the outputs (``<title>.ds`` and ``<title>.wav``).
        variance_exp: Experiment name of the variance model.
        acoustic_exp: Experiment name of the acoustic model.
        seed: Seed forwarded to both inference stages.
        num_runs: Number of runs for the acoustic stage only; the variance
            stage always runs once.
        key_shift: Transposition passed to ``trans_key``; 0 skips it.
        gender: Optional value stored under ``"gender"`` in every segment,
            applied only when the loaded hparams enable
            ``use_key_shift_embed``.

    Returns:
        Path of the generated ``<title>.wav`` inside ``output_dir``.

    Raises:
        FileNotFoundError: ``ds_path`` does not exist.
        ValueError: the input DS or the variance output DS fails validation.
        RuntimeError: a stage did not produce its expected output file.
    """
    # Fail fast, before any global state (sys.argv / hparams) is touched.
    if not ds_path.exists():
        raise FileNotFoundError(f"Input DS file not found: {ds_path}")

    # set_hparams is driven through sys.argv; remember the caller's argv so
    # it can be restored no matter how this function exits.
    saved_argv = sys.argv
    try:
        # Initial hparams load, needed for the use_spk_id /
        # use_key_shift_embed checks below.
        # NOTE(review): this reads the config from PROJECT_ROOT/checkpoints
        # while the later loads use HF_CHECKPOINTS_DIR -- confirm that both
        # locations are intended.
        _apply_hparams(
            PROJECT_ROOT / "checkpoints" / variance_exp / "config.yaml",
            variance_exp,
        )

        # Load and validate the input DS.
        params = _load_ds_segments(ds_path, "Invalid input DS file")

        # Ensure ph_seq present: fall back to the characters of "text"
        # (spaces removed), separated by single spaces.
        for p in params:
            if "ph_seq" not in p:
                text = p.get("text", "")
                p["ph_seq"] = " ".join(text.replace(" ", ""))

        # Transpose
        if key_shift != 0:
            params = trans_key(params, key_shift)

        # Speaker mix
        spk_mix = parse_commandline_spk_mix(None) if hparams.get("use_spk_id") else None
        for p in params:
            if gender is not None and hparams.get("use_key_shift_embed"):
                p["gender"] = gender
            if spk_mix is not None:
                p["spk_mix"] = spk_mix

        # ==== Variance Inference ==== #
        print(f"[pipeline] Loading variance exp: {variance_exp}")
        variance_config_path = os.path.join(HF_CHECKPOINTS_DIR, variance_exp, "config.yaml")
        _apply_hparams(variance_config_path, variance_exp)
        print("[pipeline] Variance hparams keys:", sorted(hparams.keys()))
        var_infer = DiffSingerVarianceInfer(ckpt_steps=None, predictions={"dur", "pitch"})
        ds_out = output_dir / f"{title}.ds"
        var_infer.run_inference(params, out_dir=output_dir, title=title, num_runs=1, seed=seed)
        if not ds_out.exists():
            raise RuntimeError(f"Variance inference failed; missing {ds_out}")

        # Reload and validate the DS written by the variance stage; it is
        # the input of the acoustic stage.
        params = _load_ds_segments(ds_out, "Invalid DS after variance inference")

        # ==== Acoustic Inference ==== #
        print(f"[pipeline] Loading acoustic exp: {acoustic_exp}")
        acoustic_config_path = os.path.join(HF_CHECKPOINTS_DIR, acoustic_exp, "config.yaml")
        _apply_hparams(acoustic_config_path, acoustic_exp)
        print("[pipeline] Acoustic hparams keys:", sorted(hparams.keys()))
        ac_infer = DiffSingerAcousticInfer(load_vocoder=True, ckpt_steps=None)
        ac_infer.run_inference(params, out_dir=output_dir, title=title, num_runs=num_runs, seed=seed)
        wav_out = output_dir / f"{title}.wav"
        if not wav_out.exists():
            raise RuntimeError(f"Acoustic inference failed; missing {wav_out}")
        return wav_out
    finally:
        # Do not leak the synthetic argv into the host process.
        sys.argv = saved_argv
if __name__ == "__main__":
    # Command-line entry point: parse the arguments and run the pipeline once.
    from argparse import ArgumentParser

    cli = ArgumentParser(description="Run full DiffSinger inference pipeline")

    # Required positionals, both parsed as paths.
    for positional in ("ds_path", "output_dir"):
        cli.add_argument(positional, type=Path)

    # Optional flags as (flag, type, default); defaults match run_inference.
    option_table = (
        ("--title", str, None),
        ("--variance_exp", str, "regular_variance_v1"),
        ("--acoustic_exp", str, "debug_test"),
        ("--seed", int, 42),
        ("--num_runs", int, 1),
        ("--key_shift", int, 0),
        ("--gender", float, None),
    )
    for flag, value_type, fallback in option_table:
        cli.add_argument(flag, type=value_type, default=fallback)

    opts = cli.parse_args()

    # Without an explicit (non-empty) title, use the DS file name minus suffix.
    run_inference(
        opts.ds_path,
        opts.output_dir,
        opts.title or opts.ds_path.stem,
        variance_exp=opts.variance_exp,
        acoustic_exp=opts.acoustic_exp,
        seed=opts.seed,
        num_runs=opts.num_runs,
        key_shift=opts.key_shift,
        gender=opts.gender,
    )
|