Spaces:
Running
Running
kevinwang676
commited on
Update inference.py
Browse files- inference.py +18 -5
inference.py
CHANGED
@@ -10,6 +10,10 @@ from SVCNN import SVCNN
|
|
10 |
from utils.tools import extract_voiced_area
|
11 |
from utils.extract_pitch import extract_pitch_ref as extract_pitch, coarse_f0
|
12 |
|
|
|
|
|
|
|
|
|
13 |
SPEAKER_INFORMATION_WEIGHTS = [
|
14 |
0, 0, 0, 0, 0, 0, # layer 0-5
|
15 |
1.0, 0, 0, 0,
|
@@ -51,12 +55,21 @@ def svc(model, src_wav_path, ref_wav_path, synth_set_path=None, f0_factor=0., sp
|
|
51 |
if synth_set_path:
|
52 |
synth_set = torch.load(synth_set_path).to(device)
|
53 |
else:
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
|
|
|
|
|
|
|
|
|
|
59 |
|
|
|
|
|
|
|
|
|
60 |
query_len = query_seq.shape[0]
|
61 |
if len(query_mask) > query_len:
|
62 |
query_mask = query_mask[:query_len]
|
|
|
10 |
from utils.tools import extract_voiced_area
|
11 |
from utils.extract_pitch import extract_pitch_ref as extract_pitch, coarse_f0
|
12 |
|
13 |
+
from Phoneme_Hallucinator_v2.utils.hparams import HParams
|
14 |
+
from Phoneme_Hallucinator_v2.models import get_model as get_hallucinator
|
15 |
+
from Phoneme_Hallucinator_v2.scripts.speech_expansion_ins import single_expand
|
16 |
+
|
17 |
SPEAKER_INFORMATION_WEIGHTS = [
|
18 |
0, 0, 0, 0, 0, 0, # layer 0-5
|
19 |
1.0, 0, 0, 0,
|
|
|
55 |
if synth_set_path:
|
56 |
synth_set = torch.load(synth_set_path).to(device)
|
57 |
else:
|
58 |
+
synth_set_path = f"matching_set/{ref_name}.pt"
|
59 |
+
synth_set = model.get_matching_set(ref_wav_path, out_path=synth_set_path).to(device)
|
60 |
+
|
61 |
+
if hallucinated_set_path is None:
|
62 |
+
params = HParams('Phoneme_Hallucinator_v2/exp/speech_XXL_cond/params.json')
|
63 |
+
Hallucinator = get_hallucinator(params)
|
64 |
+
Hallucinator.load()
|
65 |
+
hallucinated_set = single_expand(synth_set_path, Hallucinator, 15000)
|
66 |
+
else:
|
67 |
+
hallucinated_set = np.load(hallucinated_set_path)
|
68 |
|
69 |
+
hallucinated_set = torch.from_numpy(hallucinated_set).to(device)
|
70 |
+
|
71 |
+
synth_set = torch.cat([synth_set, hallucinated_set], dim=0)
|
72 |
+
|
73 |
query_len = query_seq.shape[0]
|
74 |
if len(query_mask) > query_len:
|
75 |
query_mask = query_mask[:query_len]
|