JacobLinCool commited on
Commit
a0ad823
·
1 Parent(s): 28cee66

feat: hubert features

Browse files
Files changed (2) hide show
  1. app.py +15 -1
  2. infer/modules/train/extract_feature_print.py +72 -113
app.py CHANGED
@@ -7,6 +7,7 @@ import shutil
7
  from glob import glob
8
  from infer.modules.train.preprocess import PreProcess
9
  from infer.modules.train.extract.extract_f0_rmvpe import FeatureInput
 
10
  from infer.modules.train.train import train
11
  from infer.lib.train.process_ckpt import extract_small_model
12
  from zero import zero
@@ -60,6 +61,19 @@ def extract_features(exp_dir: str) -> str:
60
  fi.logfile.seek(0)
61
  log = fi.logfile.read()
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  if err:
64
  log = f"Error: {err}\n{log}"
65
 
@@ -195,7 +209,7 @@ with gr.Blocks() as app:
195
  with gr.Column():
196
  train_btn = gr.Button(value="Train", variant="primary")
197
  with gr.Column():
198
- latest_model = gr.File(label="Latest model")
199
 
200
  with gr.Row():
201
  with gr.Column():
 
7
  from glob import glob
8
  from infer.modules.train.preprocess import PreProcess
9
  from infer.modules.train.extract.extract_f0_rmvpe import FeatureInput
10
+ from infer.modules.train.extract_feature_print import HubertFeatureExtractor
11
  from infer.modules.train.train import train
12
  from infer.lib.train.process_ckpt import extract_small_model
13
  from zero import zero
 
61
  fi.logfile.seek(0)
62
  log = fi.logfile.read()
63
 
64
+ if err:
65
+ log = f"Error: {err}\n{log}"
66
+ return log
67
+
68
+ hfe = HubertFeatureExtractor(exp_dir)
69
+ try:
70
+ hfe.run()
71
+ except Exception as e:
72
+ err = e
73
+
74
+ hfe.logfile.seek(0)
75
+ log += hfe.logfile.read()
76
+
77
  if err:
78
  log = f"Error: {err}\n{log}"
79
 
 
209
  with gr.Column():
210
  train_btn = gr.Button(value="Train", variant="primary")
211
  with gr.Column():
212
+ latest_model = gr.File(label="Latest checkpoint")
213
 
214
  with gr.Row():
215
  with gr.Column():
infer/modules/train/extract_feature_print.py CHANGED
@@ -1,65 +1,30 @@
1
  import os
2
- import sys
3
  import traceback
4
-
5
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
6
- os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
7
-
8
- device = sys.argv[1]
9
- n_part = int(sys.argv[2])
10
- i_part = int(sys.argv[3])
11
- if len(sys.argv) == 7:
12
- exp_dir = sys.argv[4]
13
- version = sys.argv[5]
14
- is_half = sys.argv[6].lower() == "true"
15
- else:
16
- i_gpu = sys.argv[4]
17
- exp_dir = sys.argv[5]
18
- os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
19
- version = sys.argv[6]
20
- is_half = sys.argv[7].lower() == "true"
21
  import fairseq
22
  import numpy as np
23
  import soundfile as sf
24
  import torch
25
  import torch.nn.functional as F
26
 
27
- if "privateuseone" not in device:
28
- device = "cpu"
29
- if torch.cuda.is_available():
30
- device = "cuda"
31
- elif torch.backends.mps.is_available():
32
- device = "mps"
33
- else:
34
- import torch_directml
35
-
36
- device = torch_directml.device(torch_directml.default_device())
37
-
38
- def forward_dml(ctx, x, scale):
39
- ctx.scale = scale
40
- res = x.clone().detach()
41
- return res
42
-
43
- fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
44
-
45
- f = open("%s/extract_f0_feature.log" % exp_dir, "a+")
46
-
47
-
48
- def printt(strr):
49
- print(strr)
50
- f.write("%s\n" % strr)
51
- f.flush()
52
 
 
 
 
 
 
53
 
54
- printt(" ".join(sys.argv))
55
  model_path = "assets/hubert/hubert_base.pt"
56
-
57
- printt("exp_dir: " + exp_dir)
58
- wavPath = "%s/1_16k_wavs" % exp_dir
59
- outPath = (
60
- "%s/3_feature256" % exp_dir if version == "v1" else "%s/3_feature768" % exp_dir
61
  )
62
- os.makedirs(outPath, exist_ok=True)
 
 
 
 
 
 
63
 
64
 
65
  # wave must be 16k, hop_size=320
@@ -77,66 +42,60 @@ def readwave(wav_path, normalize=False):
77
  return feats
78
 
79
 
80
- # HuBERT model
81
- printt("load model(s) from {}".format(model_path))
82
- # if hubert model is exist
83
- if os.access(model_path, os.F_OK) == False:
84
- printt(
85
- "Error: Extracting is shut down because %s does not exist, you may download it from https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main"
86
- % model_path
87
- )
88
- exit(0)
89
- models, saved_cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
90
- [model_path],
91
- suffix="",
92
- )
93
- model = models[0]
94
- model = model.to(device)
95
- printt("move model to %s" % device)
96
- if is_half:
97
- if device not in ["mps", "cpu"]:
98
- model = model.half()
99
- model.eval()
100
-
101
- todo = sorted(list(os.listdir(wavPath)))[i_part::n_part]
102
- n = max(1, len(todo) // 10) # 最多打印十条
103
- if len(todo) == 0:
104
- printt("no-feature-todo")
105
- else:
106
- printt("all-feature-%s" % len(todo))
107
- for idx, file in enumerate(todo):
108
- try:
109
- if file.endswith(".wav"):
110
- wav_path = "%s/%s" % (wavPath, file)
111
- out_path = "%s/%s" % (outPath, file.replace("wav", "npy"))
112
-
113
- if os.path.exists(out_path):
114
- continue
115
-
116
- feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
117
- padding_mask = torch.BoolTensor(feats.shape).fill_(False)
118
- inputs = {
119
- "source": (
120
- feats.half().to(device)
121
- if is_half and device not in ["mps", "cpu"]
122
- else feats.to(device)
123
- ),
124
- "padding_mask": padding_mask.to(device),
125
- "output_layer": 9 if version == "v1" else 12, # layer 9
126
- }
127
- with torch.no_grad():
128
- logits = model.extract_features(**inputs)
129
- feats = (
130
- model.final_proj(logits[0]) if version == "v1" else logits[0]
131
- )
132
-
133
- feats = feats.squeeze(0).float().cpu().numpy()
134
- if np.isnan(feats).sum() == 0:
135
- np.save(out_path, feats, allow_pickle=False)
136
- else:
137
- printt("%s-contains nan" % file)
138
- if idx % n == 0:
139
- printt("now-%s,all-%s,%s,%s" % (len(todo), idx, file, feats.shape))
140
- except:
141
- printt(traceback.format_exc())
142
- printt("all-feature-done")
 
1
  import os
 
2
  import traceback
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import fairseq
4
  import numpy as np
5
  import soundfile as sf
6
  import torch
7
  import torch.nn.functional as F
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
+ device = "cpu"
11
+ if torch.cuda.is_available():
12
+ device = "cuda"
13
+ elif torch.backends.mps.is_available():
14
+ device = "mps"
15
 
 
16
  model_path = "assets/hubert/hubert_base.pt"
17
+ models, saved_cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
18
+ [model_path],
19
+ suffix="",
 
 
20
  )
21
+ model = models[0]
22
+ model = model.to(device)
23
+ is_half = False
24
+ if is_half:
25
+ if device not in ["mps", "cpu"]:
26
+ model = model.half()
27
+ model.eval()
28
 
29
 
30
  # wave must be 16k, hop_size=320
 
42
  return feats
43
 
44
 
45
+ class HubertFeatureExtractor:
46
+ def __init__(self, exp_dir: str):
47
+ self.exp_dir = exp_dir
48
+ self.logfile = open("%s/extract_f0_feature.log" % exp_dir, "a+")
49
+ self.wavPath = "%s/1_16k_wavs" % exp_dir
50
+ self.outPath = "%s/3_feature768" % exp_dir
51
+ os.makedirs(self.outPath, exist_ok=True)
52
+
53
+ def println(self, strr):
54
+ print(strr)
55
+ self.logfile.write("%s\n" % strr)
56
+ self.logfile.flush()
57
+
58
+ def run(self):
59
+ todo = sorted(list(os.listdir(self.wavPath)))
60
+ n = max(1, len(todo) // 10) # 最多打印十条
61
+ if len(todo) == 0:
62
+ self.println("no-feature-todo")
63
+ else:
64
+ self.println("all-feature-%s" % len(todo))
65
+ for idx, file in enumerate(todo):
66
+ try:
67
+ if file.endswith(".wav"):
68
+ wav_path = "%s/%s" % (self.wavPath, file)
69
+ out_path = "%s/%s" % (self.outPath, file.replace("wav", "npy"))
70
+
71
+ if os.path.exists(out_path):
72
+ continue
73
+
74
+ feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
75
+ padding_mask = torch.BoolTensor(feats.shape).fill_(False)
76
+ inputs = {
77
+ "source": (
78
+ feats.half().to(device)
79
+ if is_half and device not in ["mps", "cpu"]
80
+ else feats.to(device)
81
+ ),
82
+ "padding_mask": padding_mask.to(device),
83
+ "output_layer": 12,
84
+ }
85
+ with torch.no_grad():
86
+ logits = model.extract_features(**inputs)
87
+ feats = logits[0]
88
+
89
+ feats = feats.squeeze(0).float().cpu().numpy()
90
+ if np.isnan(feats).sum() == 0:
91
+ np.save(out_path, feats, allow_pickle=False)
92
+ else:
93
+ self.println("%s-contains nan" % file)
94
+ if idx % n == 0:
95
+ self.println(
96
+ "now-%s,all-%s,%s,%s"
97
+ % (len(todo), idx, file, feats.shape)
98
+ )
99
+ except:
100
+ self.println(traceback.format_exc())
101
+ self.println("all-feature-done")