maxmax20160403 commited on
Commit
5ab1e9c
1 Parent(s): a383264
Files changed (1) hide show
  1. hubert/inference.py +2 -4
hubert/inference.py CHANGED
@@ -11,7 +11,6 @@ from hubert import hubert_model
11
  def load_model(path, device):
12
  model = hubert_model.hubert_soft(path)
13
  model.eval()
14
- model.half()
15
  model.to(device)
16
  return model
17
 
@@ -19,7 +18,7 @@ def load_model(path, device):
19
  def pred_vec(model, wavPath, vecPath, device):
20
  feats = load_audio(wavPath)
21
  feats = torch.from_numpy(feats).to(device)
22
- feats = feats[None, None, :].half()
23
  with torch.no_grad():
24
  vec = model.units(feats).squeeze().data.cpu().float().numpy()
25
  # print(vec.shape) # [length, dim=256] hop=320
@@ -38,8 +37,7 @@ if __name__ == "__main__":
38
  wavPath = args.wav
39
  vecPath = args.vec
40
 
41
- assert torch.cuda.is_available()
42
- device = "cuda"
43
  hubert = load_model(os.path.join(
44
  "hubert_pretrain", "hubert-soft-0d54a1f4.pt"), device)
45
  pred_vec(hubert, wavPath, vecPath, device)
 
11
  def load_model(path, device):
12
  model = hubert_model.hubert_soft(path)
13
  model.eval()
 
14
  model.to(device)
15
  return model
16
 
 
18
  def pred_vec(model, wavPath, vecPath, device):
19
  feats = load_audio(wavPath)
20
  feats = torch.from_numpy(feats).to(device)
21
+ feats = feats[None, None, :]
22
  with torch.no_grad():
23
  vec = model.units(feats).squeeze().data.cpu().float().numpy()
24
  # print(vec.shape) # [length, dim=256] hop=320
 
37
  wavPath = args.wav
38
  vecPath = args.vec
39
 
40
+ device = "cpu"
 
41
  hubert = load_model(os.path.join(
42
  "hubert_pretrain", "hubert-soft-0d54a1f4.pt"), device)
43
  pred_vec(hubert, wavPath, vecPath, device)