kevinwang676 committed on
Commit
88b5ee2
1 Parent(s): 90ce5ab

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +9 -7
inference.py CHANGED
@@ -31,7 +31,7 @@ APPLIED_INFORMATION_WEIGHTS = [
31
  ]
32
 
33
 
34
- def svc(model, src_wav_path, ref_wav_path, out_dir, device, f0_factor, speech_enroll=False):
35
 
36
  wav_name = os.path.basename(src_wav_path).split('.')[0]
37
  ref_name = os.path.basename(ref_wav_path).split('.')[0]
@@ -48,12 +48,14 @@ def svc(model, src_wav_path, ref_wav_path, out_dir, device, f0_factor, speech_en
48
  query_seq = model.get_features(
49
  src_wav_path, weights=synth_weights)
50
 
51
- synth_set_path = f"matching_set/{ref_name}.pt"
52
- synth_set = model.get_matching_set(ref_wav_path, out_path=synth_set_path).to(device)
53
- hallucinated_set_path = f"matching_set/hallucinated_set/{ref_name}_hallucinated_15k.npy"
54
- os.system(f"python Phoneme_Hallucinator_v2/scripts/speech_expansion_ins.py --cfg_file Phoneme_Hallucinator_v2/exp/speech_XXL_cond/params.json --num_samples 15000 --path {synth_set_path} --out_path {hallucinated_set_path}")
55
- hallucinated_set = torch.from_numpy(np.load(hallucinated_set_path)).to(device)
56
- synth_set = torch.cat([synth_set, hallucinated_set], dim=0)
 
 
57
 
58
  query_len = query_seq.shape[0]
59
  if len(query_mask) > query_len:
 
31
  ]
32
 
33
 
34
+ def svc(model, src_wav_path, ref_wav_path, synth_set_path=None, f0_factor=0., speech_enroll=False, out_dir="output", hallucinated_set_path=None, device='cpu'):
35
 
36
  wav_name = os.path.basename(src_wav_path).split('.')[0]
37
  ref_name = os.path.basename(ref_wav_path).split('.')[0]
 
48
  query_seq = model.get_features(
49
  src_wav_path, weights=synth_weights)
50
 
51
+ if synth_set_path:
52
+ synth_set = torch.load(synth_set_path).to(device)
53
+ else:
54
+ synth_set = model.get_matching_set(ref_wav_path).to(device)
55
+
56
+ if hallucinated_set_path:
57
+ hallucinated_set = torch.from_numpy(np.load(hallucinated_set_path)).to(device)
58
+ synth_set = torch.cat([synth_set, hallucinated_set], dim=0)
59
 
60
  query_len = query_seq.shape[0]
61
  if len(query_mask) > query_len: