kevinwang676 committed on
Commit
65de068
1 Parent(s): 5d4265b

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +18 -5
inference.py CHANGED
@@ -10,6 +10,10 @@ from SVCNN import SVCNN
10
  from utils.tools import extract_voiced_area
11
  from utils.extract_pitch import extract_pitch_ref as extract_pitch, coarse_f0
12
 
 
 
 
 
13
  SPEAKER_INFORMATION_WEIGHTS = [
14
  0, 0, 0, 0, 0, 0, # layer 0-5
15
  1.0, 0, 0, 0,
@@ -51,12 +55,21 @@ def svc(model, src_wav_path, ref_wav_path, synth_set_path=None, f0_factor=0., sp
51
  if synth_set_path:
52
  synth_set = torch.load(synth_set_path).to(device)
53
  else:
54
- synth_set = model.get_matching_set(ref_wav_path).to(device)
55
-
56
- if hallucinated_set_path:
57
- hallucinated_set = torch.from_numpy(np.load(hallucinated_set_path)).to(device)
58
- synth_set = torch.cat([synth_set, hallucinated_set], dim=0)
 
 
 
 
 
59
 
 
 
 
 
60
  query_len = query_seq.shape[0]
61
  if len(query_mask) > query_len:
62
  query_mask = query_mask[:query_len]
 
10
  from utils.tools import extract_voiced_area
11
  from utils.extract_pitch import extract_pitch_ref as extract_pitch, coarse_f0
12
 
13
+ from Phoneme_Hallucinator_v2.utils.hparams import HParams
14
+ from Phoneme_Hallucinator_v2.models import get_model as get_hallucinator
15
+ from Phoneme_Hallucinator_v2.scripts.speech_expansion_ins import single_expand
16
+
17
  SPEAKER_INFORMATION_WEIGHTS = [
18
  0, 0, 0, 0, 0, 0, # layer 0-5
19
  1.0, 0, 0, 0,
 
55
  if synth_set_path:
56
  synth_set = torch.load(synth_set_path).to(device)
57
  else:
58
+ synth_set_path = f"matching_set/{ref_name}.pt"
59
+ synth_set = model.get_matching_set(ref_wav_path, out_path=synth_set_path).to(device)
60
+
61
+ if hallucinated_set_path is None:
62
+ params = HParams('Phoneme_Hallucinator_v2/exp/speech_XXL_cond/params.json')
63
+ Hallucinator = get_hallucinator(params)
64
+ Hallucinator.load()
65
+ hallucinated_set = single_expand(synth_set_path, Hallucinator, 15000)
66
+ else:
67
+ hallucinated_set = np.load(hallucinated_set_path)
68
 
69
+ hallucinated_set = torch.from_numpy(hallucinated_set).to(device)
70
+
71
+ synth_set = torch.cat([synth_set, hallucinated_set], dim=0)
72
+
73
  query_len = query_seq.shape[0]
74
  if len(query_mask) > query_len:
75
  query_mask = query_mask[:query_len]