JacobLinCool commited on
Commit
4b56fbf
·
1 Parent(s): fa9dd69

feat: train index

Browse files
Files changed (2) hide show
  1. app.py +75 -2
  2. infer/modules/train/train.py +5 -5
app.py CHANGED
@@ -1,4 +1,8 @@
1
  import os
 
 
 
 
2
 
3
  os.environ["PYTORCH_JIT"] = "0v"
4
 
@@ -7,6 +11,7 @@ import gradio as gr
7
  import zipfile
8
  import tempfile
9
  import shutil
 
10
  from glob import glob
11
  from infer.modules.train.preprocess import PreProcess
12
  from infer.modules.train.extract.extract_f0_rmvpe import FeatureInput
@@ -193,6 +198,66 @@ def download_weight(exp_dir: str) -> str:
193
  return "assets/weights/%s.pth" % name
194
 
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  def download_expdir(exp_dir: str) -> str:
197
  shutil.make_archive(exp_dir, "zip", exp_dir)
198
  return f"{exp_dir}.zip"
@@ -206,7 +271,7 @@ def restore_expdir(zip: str) -> str:
206
 
207
  with gr.Blocks() as app:
208
  # allow user to manually select the experiment directory
209
- exp_dir = gr.Textbox(label="Experiment directory", visible=True, interactive=True)
210
 
211
  with gr.Tabs():
212
  with gr.Tab(label="New / Restore"):
@@ -244,8 +309,10 @@ with gr.Blocks() as app:
244
  with gr.Tab(label="Train"):
245
  with gr.Row():
246
  train_btn = gr.Button(value="Train", variant="primary")
247
- with gr.Row():
248
  latest_model = gr.File(label="Latest checkpoint")
 
 
 
249
 
250
  with gr.Tab(label="Download"):
251
  with gr.Row():
@@ -278,6 +345,12 @@ with gr.Blocks() as app:
278
  outputs=[latest_model],
279
  )
280
 
 
 
 
 
 
 
281
  download_weight_btn.click(
282
  fn=download_weight,
283
  inputs=[exp_dir],
 
1
  import os
2
+ import traceback
3
+
4
+ import numpy as np
5
+ from sklearn.cluster import MiniBatchKMeans
6
 
7
  os.environ["PYTORCH_JIT"] = "0v"
8
 
 
11
  import zipfile
12
  import tempfile
13
  import shutil
14
+ import faiss
15
  from glob import glob
16
  from infer.modules.train.preprocess import PreProcess
17
  from infer.modules.train.extract.extract_f0_rmvpe import FeatureInput
 
198
  return "assets/weights/%s.pth" % name
199
 
200
 
201
+ def train_index(exp_dir: str) -> str:
202
+ feature_dir = "%s/3_feature768" % (exp_dir)
203
+ if not os.path.exists(feature_dir):
204
+ raise gr.Error("Please extract features first.")
205
+ listdir_res = list(os.listdir(feature_dir))
206
+ if len(listdir_res) == 0:
207
+ raise gr.Error("Please extract features first.")
208
+ npys = []
209
+ for name in sorted(listdir_res):
210
+ phone = np.load("%s/%s" % (feature_dir, name))
211
+ npys.append(phone)
212
+ big_npy = np.concatenate(npys, 0)
213
+ big_npy_idx = np.arange(big_npy.shape[0])
214
+ np.random.shuffle(big_npy_idx)
215
+ big_npy = big_npy[big_npy_idx]
216
+ if big_npy.shape[0] > 2e5:
217
+ print("Trying doing kmeans %s shape to 10k centers." % big_npy.shape[0])
218
+ try:
219
+ big_npy = (
220
+ MiniBatchKMeans(
221
+ n_clusters=10000,
222
+ verbose=True,
223
+ batch_size=256 * 8,
224
+ compute_labels=False,
225
+ init="random",
226
+ )
227
+ .fit(big_npy)
228
+ .cluster_centers_
229
+ )
230
+ except:
231
+ info = traceback.format_exc()
232
+ print(info)
233
+ raise gr.Error(info)
234
+
235
+ np.save("%s/total_fea.npy" % exp_dir, big_npy)
236
+ n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
237
+ print("%s,%s" % (big_npy.shape, n_ivf))
238
+ index = faiss.index_factory(768, "IVF%s,Flat" % n_ivf)
239
+ # index = faiss.index_factory(256if version19=="v1"else 768, "IVF%s,PQ128x4fs,RFlat"%n_ivf)
240
+ print("training")
241
+ index_ivf = faiss.extract_index_ivf(index) #
242
+ index_ivf.nprobe = 1
243
+ index.train(big_npy)
244
+ faiss.write_index(
245
+ index,
246
+ "%s/trained_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe),
247
+ )
248
+ print("adding")
249
+ batch_size_add = 8192
250
+ for i in range(0, big_npy.shape[0], batch_size_add):
251
+ index.add(big_npy[i : i + batch_size_add])
252
+ faiss.write_index(
253
+ index,
254
+ "%s/added_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe),
255
+ )
256
+ print("built added_IVF%s_Flat_nprobe_%s.index" % (n_ivf, index_ivf.nprobe))
257
+
258
+ return "%s/added_IVF%s_Flat_nprobe_%s.index" % (exp_dir, n_ivf, index_ivf.nprobe)
259
+
260
+
261
  def download_expdir(exp_dir: str) -> str:
262
  shutil.make_archive(exp_dir, "zip", exp_dir)
263
  return f"{exp_dir}.zip"
 
271
 
272
  with gr.Blocks() as app:
273
  # allow user to manually select the experiment directory
274
+ exp_dir = gr.Textbox(label="Experiment directory (don't touch it unless you know what you are doing)", visible=True, interactive=True)
275
 
276
  with gr.Tabs():
277
  with gr.Tab(label="New / Restore"):
 
309
  with gr.Tab(label="Train"):
310
  with gr.Row():
311
  train_btn = gr.Button(value="Train", variant="primary")
 
312
  latest_model = gr.File(label="Latest checkpoint")
313
+ with gr.Row():
314
+ train_index_btn = gr.Button(value="Train index", variant="primary")
315
+ trained_index = gr.File(label="Trained index")
316
 
317
  with gr.Tab(label="Download"):
318
  with gr.Row():
 
345
  outputs=[latest_model],
346
  )
347
 
348
+ train_index_btn.click(
349
+ fn=train_index,
350
+ inputs=[exp_dir],
351
+ outputs=[trained_index],
352
+ )
353
+
354
  download_weight_btn.click(
355
  fn=download_weight,
356
  inputs=[exp_dir],
infer/modules/train/train.py CHANGED
@@ -200,8 +200,7 @@ def run(rank, n_gpus, hps, logger: logging.Logger, state):
200
  )
201
  state["global_step"] = (epoch_str - 1) * len(train_loader)
202
  print("loaded", epoch_str)
203
- # epoch_str = 1
204
- # global_step = 0
205
  except: # 如果首次不能加载,加载pretrain
206
  # traceback.print_exc()
207
  epoch_str = 1
@@ -248,7 +247,7 @@ def run(rank, n_gpus, hps, logger: logging.Logger, state):
248
  scaler = GradScaler(enabled=hps.train.fp16_run)
249
 
250
  cache = []
251
- trained = 0
252
  for epoch in range(epoch_str, hps.train.epochs + 1):
253
  if rank == 0:
254
  train_and_evaluate(
@@ -283,8 +282,9 @@ def run(rank, n_gpus, hps, logger: logging.Logger, state):
283
  scheduler_g.step()
284
  scheduler_d.step()
285
 
286
- trained += 1
287
- if trained >= 10:
 
288
  break
289
 
290
 
 
200
  )
201
  state["global_step"] = (epoch_str - 1) * len(train_loader)
202
  print("loaded", epoch_str)
203
+ epoch_str += 1
 
204
  except: # 如果首次不能加载,加载pretrain
205
  # traceback.print_exc()
206
  epoch_str = 1
 
247
  scaler = GradScaler(enabled=hps.train.fp16_run)
248
 
249
  cache = []
250
+ saved = 0
251
  for epoch in range(epoch_str, hps.train.epochs + 1):
252
  if rank == 0:
253
  train_and_evaluate(
 
282
  scheduler_g.step()
283
  scheduler_d.step()
284
 
285
+ if epoch % hps.save_every_epoch == 0 and rank == 0:
286
+ saved += 1
287
+ if saved >= 2:
288
  break
289
 
290