nateraw committed on
Commit 25ab393
Parent: b0324c8

Synced repo using 'sync_with_huggingface' Github Action

Files changed (2)
  1. app.py +58 -16
  2. training_so_vits_svc_fork.ipynb +540 -0
app.py CHANGED
@@ -14,12 +14,23 @@ from so_vits_svc_fork.hparams import HParams
 from so_vits_svc_fork.inference.core import Svc
 
 
-##########################################################
-# REPLACE THESE VALUES TO CHANGE THE MODEL REPO/CKPT NAME
-##########################################################
+###################################################################
+# REPLACE THESE VALUES TO CHANGE THE MODEL REPO/CKPT NAME/SETTINGS
+###################################################################
+# The Hugging Face Hub repo ID
 repo_id = "dog/kanye"
+# If None, Uses latest ckpt in the repo
 ckpt_name = None
-##########################################################
+# If None, Uses "kmeans.pt" if it exists in the repo
+cluster_model_name = None
+# Set the default f0 type to use - use the one it was trained on.
+# The default for so-vits-svc-fork is "dio".
+# Options: "crepe", "crepe-tiny", "parselmouth", "dio", "harvest"
+default_f0_method = "crepe"
+# The default ratio of cluster inference to SVC inference.
+# If cluster_model_name is not found in the repo, this is set to 0.
+default_cluster_infer_ratio = 0.5
+###################################################################
 
 # Figure out the latest generator by taking highest value one.
 # Ex. if the repo has: G_0.pth, G_100.pth, G_200.pth, we'd use G_200.pth
@@ -33,12 +44,21 @@ if ckpt_name is None:
     )[-1]
     ckpt_name = f"G_{latest_id}.pth"
 
+cluster_model_name = cluster_model_name or "kmeans.pt"
+if cluster_model_name in list_repo_files(repo_id):
+    print(f"Found Cluster model - Downloading {cluster_model_name} from {repo_id}")
+    cluster_model_path = hf_hub_download(repo_id, cluster_model_name)
+else:
+    print(f"Could not find {cluster_model_name} in {repo_id}. Using None")
+    cluster_model_path = None
+default_cluster_infer_ratio = default_cluster_infer_ratio if cluster_model_path else 0
+
 generator_path = hf_hub_download(repo_id, ckpt_name)
 config_path = hf_hub_download(repo_id, "config.json")
 hparams = HParams(**json.loads(Path(config_path).read_text()))
 speakers = list(hparams.spk.keys())
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model = Svc(net_g_path=generator_path, config_path=config_path, device=device, cluster_model_path=None)
+model = Svc(net_g_path=generator_path, config_path=config_path, device=device, cluster_model_path=cluster_model_path)
 demucs_model = get_model(DEFAULT_MODEL)
 
 
@@ -133,7 +153,7 @@ def predict(
 
 
 def predict_song_from_yt(
-    ytid,
+    ytid_or_url,
     start,
     end,
     speaker=speakers[0],
@@ -147,7 +167,14 @@ def predict_song_from_yt(
     chunk_seconds: float = 0.5,
     absolute_thresh: bool = False,
 ):
-    original_track_filepath = download_youtube_clip(ytid, start, end, "track.wav", force=True)
+    original_track_filepath = download_youtube_clip(
+        ytid_or_url,
+        start,
+        end,
+        "track.wav",
+        force=True,
+        url_base="" if ytid_or_url.startswith("http") else "https://www.youtube.com/watch?v=",
+    )
     vox_wav, inst_wav = extract_vocal_demucs(demucs_model, original_track_filepath)
     if transpose != 0:
         inst_wav = librosa.effects.pitch_shift(inst_wav.T, sr=model.target_sample, n_steps=transpose).T
@@ -187,9 +214,13 @@ interface_mic = gr.Interface(
         gr.Audio(type="filepath", source="microphone", label="Source Audio"),
         gr.Slider(-12, 12, value=0, step=1, label="Transpose (Semitones)"),
         gr.Checkbox(False, label="Auto Predict F0"),
-        gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="cluster infer ratio"),
+        gr.Slider(0.0, 1.0, value=default_cluster_infer_ratio, step=0.1, label="cluster infer ratio"),
         gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="noise scale"),
-        gr.Dropdown(choices=["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"], value="dio", label="f0 method"),
+        gr.Dropdown(
+            choices=["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
+            value=default_f0_method,
+            label="f0 method",
+        ),
     ],
     outputs="audio",
     title="Voice Cloning",
@@ -203,9 +234,13 @@ interface_file = gr.Interface(
         gr.Audio(type="filepath", source="upload", label="Source Audio"),
         gr.Slider(-12, 12, value=0, step=1, label="Transpose (Semitones)"),
         gr.Checkbox(False, label="Auto Predict F0"),
-        gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="cluster infer ratio"),
+        gr.Slider(0.0, 1.0, value=default_cluster_infer_ratio, step=0.1, label="cluster infer ratio"),
         gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="noise scale"),
-        gr.Dropdown(choices=["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"], value="dio", label="f0 method"),
+        gr.Dropdown(
+            choices=["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
+            value=default_f0_method,
+            label="f0 method",
+        ),
     ],
     outputs="audio",
     title="Voice Cloning",
@@ -215,23 +250,30 @@ interface_file = gr.Interface(
 interface_yt = gr.Interface(
     predict_song_from_yt,
     inputs=[
-        "text",
+        gr.Textbox(
+            label="YouTube URL or ID", info="A YouTube URL (or ID) to a song on YouTube you want to clone from"
+        ),
         gr.Number(value=0, label="Start Time (seconds)"),
         gr.Number(value=15, label="End Time (seconds)"),
         gr.Dropdown(speakers, value=speakers[0], label="Target Speaker"),
         gr.Slider(-12, 12, value=0, step=1, label="Transpose (Semitones)"),
         gr.Checkbox(False, label="Auto Predict F0"),
-        gr.Slider(0.0, 1.0, value=0.0, step=0.1, label="cluster infer ratio"),
+        gr.Slider(0.0, 1.0, value=default_cluster_infer_ratio, step=0.1, label="cluster infer ratio"),
         gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="noise scale"),
-        gr.Dropdown(choices=["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"], value="dio", label="f0 method"),
+        gr.Dropdown(
+            choices=["crepe", "crepe-tiny", "parselmouth", "dio", "harvest"],
+            value=default_f0_method,
+            label="f0 method",
+        ),
     ],
     outputs=["audio", "audio"],
     title="Voice Cloning",
     description=description,
     article=article,
     examples=[
-        ["COz9lDCFHjw", 75, 90, speakers[0], 0, False, 0.0, 0.4, "dio"],
-        ["Wvm5GuDfAas", 15, 30, speakers[0], 0, False, 0.0, 0.4, "crepe"],
+        ["COz9lDCFHjw", 75, 90, speakers[0], 0, False, default_cluster_infer_ratio, 0.4, default_f0_method],
+        ["dQw4w9WgXcQ", 21, 35, speakers[0], 0, False, default_cluster_infer_ratio, 0.4, default_f0_method],
+        ["Wvm5GuDfAas", 15, 30, speakers[0], 0, False, default_cluster_infer_ratio, 0.4, default_f0_method],
     ],
 )
 interface = gr.TabbedInterface(
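Note: the hunks above elide the checkpoint-selection code that the "Figure out the latest generator" comment refers to (new lines 37-43). A minimal sketch of that logic, reconstructed from the visible context (`if ckpt_name is None:` ... `)[-1]` ... `ckpt_name = f"G_{latest_id}.pth"`) rather than copied from the commit:

# Hedged reconstruction, not the exact committed code: pick the generator
# checkpoint with the highest step number from the Hub repo.
# repo_id and ckpt_name are defined in the settings block above.
import re
from pathlib import Path
from huggingface_hub import list_repo_files

if ckpt_name is None:
    latest_id = sorted(
        int(Path(x).stem.split("_")[1])
        for x in list_repo_files(repo_id)
        if re.match(r"G_\d+\.pth$", x)
    )[-1]
    ckpt_name = f"G_{latest_id}.pth"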
training_so_vits_svc_fork.ipynb ADDED
@@ -0,0 +1,540 @@
+{
+"cells": [
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "view-in-github",
+"colab_type": "text"
+},
+"source": [
+"<a href=\"https://colab.research.google.com/github/nateraw/voice-cloning/blob/main/training_so_vits_svc_fork.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "jIcNJ5QfDsV_"
+},
+"outputs": [],
+"source": [
+"# %%capture\n",
+"! pip install git+https://github.com/nateraw/so-vits-svc-fork@main\n",
+"! pip install openai-whisper yt-dlp huggingface_hub demucs"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "6uZAhUPOhFv9"
+},
+"source": [
+"---\n",
+"\n",
+"# Restart runtime\n",
+"\n",
+"After running the cell above, you'll need to restart the Colab runtime because we installed a different version of numpy.\n",
+"\n",
+"`Runtime -> Restart runtime`\n",
+"\n",
+"---"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "DROusQatF-wF"
+},
+"outputs": [],
+"source": [
+"from huggingface_hub import login\n",
+"\n",
+"login()"
+]
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Settings"
+],
+"metadata": {
+"id": "yOM9WWmmRqTA"
+}
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "5oTDjDEKFz3W"
+},
+"outputs": [],
+"source": [
+"CHARACTER = \"kanye\"\n",
+"DO_EXTRACT_VOCALS = False\n",
+"MODEL_REPO_ID = \"dog/kanye\""
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"id": "BFd_ly1P_5Ht"
+},
+"source": [
+"## Data Preparation\n",
+"\n",
+"Prepare a data.csv file here with `ytid,start,end` as the first line (they're the expected column names). Then, prepare a training set given YouTube IDs and their start and end segment times in seconds. Try to pick segments that have dry vocal only, as that'll provide the best results.\n",
+"\n",
+"An example is given below for Kanye West."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "rBrtgDtWmhRb"
+},
+"outputs": [],
+"source": [
+"%%writefile data.csv\n",
+"ytid,start,end\n",
+"lkK4de9nbzQ,0,137\n",
+"gXU9Am2Seo0,30,69\n",
+"gXU9Am2Seo0,94,135\n",
+"iVgrhWvQpqU,0,55\n",
+"iVgrhWvQpqU,58,110\n",
+"UIV-q-gneKA,85,99\n",
+"UIV-q-gneKA,110,125\n",
+"UIV-q-gneKA,127,141\n",
+"UIV-q-gneKA,173,183\n",
+"GmlyYCGE9ak,0,102\n",
+"x-7aWcPmJ60,25,43\n",
+"x-7aWcPmJ60,47,72\n",
+"x-7aWcPmJ60,98,113\n",
+"DK2LCIzIBrU,0,56\n",
+"DK2LCIzIBrU,80,166\n",
+"_W56nZk0fCI,184,224"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "cxxp4uYoC0aG"
+},
+"outputs": [],
+"source": [
+"import subprocess\n",
+"from pathlib import Path\n",
+"import librosa\n",
+"from scipy.io import wavfile\n",
+"import numpy as np\n",
+"from demucs.pretrained import get_model, DEFAULT_MODEL\n",
+"from demucs.apply import apply_model\n",
+"import torch\n",
+"import csv\n",
+"import whisper\n",
+"\n",
+"\n",
+"def download_youtube_clip(video_identifier, start_time, end_time, output_filename, num_attempts=5, url_base=\"https://www.youtube.com/watch?v=\"):\n",
+"    status = False\n",
+"\n",
+"    output_path = Path(output_filename)\n",
+"    if output_path.exists():\n",
+"        return True, \"Already Downloaded\"\n",
+"\n",
+"    command = f\"\"\"\n",
+"    yt-dlp --quiet --no-warnings -x --audio-format wav -f bestaudio -o \"{output_filename}\" --download-sections \"*{start_time}-{end_time}\" \"{url_base}{video_identifier}\"\n",
+"    \"\"\".strip()\n",
+"\n",
+"    attempts = 0\n",
+"    while True:\n",
+"        try:\n",
+"            output = subprocess.check_output(command, shell=True, stderr=subprocess.STDOUT)\n",
+"        except subprocess.CalledProcessError as err:\n",
+"            attempts += 1\n",
+"            if attempts == num_attempts:\n",
+"                return status, err.output\n",
+"        else:\n",
+"            break\n",
+"\n",
+"    status = output_path.exists()\n",
+"    return status, \"Downloaded\"\n",
+"\n",
+"\n",
+"def split_long_audio(model, filepaths, character_name, save_dir=\"data_dir\", out_sr=44100):\n",
+"    if isinstance(filepaths, str):\n",
+"        filepaths = [filepaths]\n",
+"\n",
+"    for file_idx, filepath in enumerate(filepaths):\n",
+"\n",
+"        save_path = Path(save_dir) / character_name\n",
+"        save_path.mkdir(exist_ok=True, parents=True)\n",
+"\n",
+"        print(f\"Transcribing file {file_idx}: '{filepath}' to segments...\")\n",
+"        result = model.transcribe(filepath, word_timestamps=True, task=\"transcribe\", beam_size=5, best_of=5)\n",
+"        segments = result['segments']\n",
+"\n",
+"        wav, sr = librosa.load(filepath, sr=None, offset=0, duration=None, mono=True)\n",
+"        wav, _ = librosa.effects.trim(wav, top_db=20)\n",
+"        peak = np.abs(wav).max()\n",
+"        if peak > 1.0:\n",
+"            wav = 0.98 * wav / peak\n",
+"        wav2 = librosa.resample(wav, orig_sr=sr, target_sr=out_sr)\n",
+"        wav2 /= max(wav2.max(), -wav2.min())\n",
+"\n",
+"        for i, seg in enumerate(segments):\n",
+"            start_time = seg['start']\n",
+"            end_time = seg['end']\n",
+"            wav_seg = wav2[int(start_time * out_sr):int(end_time * out_sr)]\n",
+"            wav_seg_name = f\"{character_name}_{file_idx}_{i}.wav\"\n",
+"            out_fpath = save_path / wav_seg_name\n",
+"            wavfile.write(out_fpath, rate=out_sr, data=(wav_seg * np.iinfo(np.int16).max).astype(np.int16))\n",
+"\n",
+"\n",
+"def extract_vocal_demucs(model, filename, out_filename, sr=44100, device=None, shifts=1, split=True, overlap=0.25, jobs=0):\n",
+"    wav, sr = librosa.load(filename, mono=False, sr=sr)\n",
+"    wav = torch.tensor(wav)\n",
+"    ref = wav.mean(0)\n",
+"    wav = (wav - ref.mean()) / ref.std()\n",
+"    sources = apply_model(\n",
+"        model,\n",
+"        wav[None],\n",
+"        device=device,\n",
+"        shifts=shifts,\n",
+"        split=split,\n",
+"        overlap=overlap,\n",
+"        progress=True,\n",
+"        num_workers=jobs\n",
+"    )[0]\n",
+"    sources = sources * ref.std() + ref.mean()\n",
+"\n",
+"    wav = sources[-1]\n",
+"    wav = wav / max(1.01 * wav.abs().max(), 1)\n",
+"    wavfile.write(out_filename, rate=sr, data=wav.numpy().T)\n",
+"    return out_filename\n",
+"\n",
+"\n",
+"def create_dataset(\n",
+"    clips_csv_filepath = \"data.csv\",\n",
+"    character = \"somebody\",\n",
+"    do_extract_vocals = False,\n",
+"    whisper_size = \"medium\",\n",
+"    # Where raw yt clips will be downloaded to\n",
+"    dl_dir = \"downloads\",\n",
+"    # Where actual data will be organized\n",
+"    data_dir = \"dataset_raw\",\n",
+"    **kwargs\n",
+"):\n",
+"    dl_path = Path(dl_dir) / character\n",
+"    dl_path.mkdir(exist_ok=True, parents=True)\n",
+"    if do_extract_vocals:\n",
+"        demucs_model = get_model(DEFAULT_MODEL)\n",
+"\n",
+"    with Path(clips_csv_filepath).open() as f:\n",
+"        reader = csv.DictReader(f)\n",
+"        for i, row in enumerate(reader):\n",
+"            outfile_path = dl_path / f\"{character}_{i:04d}.wav\"\n",
+"            download_youtube_clip(row['ytid'], row['start'], row['end'], outfile_path)\n",
+"            if do_extract_vocals:\n",
+"                extract_vocal_demucs(demucs_model, outfile_path, outfile_path)\n",
+"\n",
+"    filenames = sorted([str(x) for x in dl_path.glob(\"*.wav\")])\n",
+"    whisper_model = whisper.load_model(whisper_size)\n",
+"    split_long_audio(whisper_model, filenames, character, data_dir)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "D9GrcDUKEGro"
+},
+"outputs": [],
+"source": [
+"\"\"\"\n",
+"Here, we override the generated training config: batch size, eval interval,\n",
+"dataloader workers, and automatic push of checkpoints to the Hub.\n",
+"\"\"\"\n",
+"\n",
+"import json\n",
+"from pathlib import Path\n",
+"import multiprocessing\n",
+"\n",
+"def update_config(config_file=\"configs/44k/config.json\"):\n",
+"    config_path = Path(config_file)\n",
+"    data = json.loads(config_path.read_text())\n",
+"    data['train']['batch_size'] = 32\n",
+"    data['train']['eval_interval'] = 500\n",
+"    data['train']['num_workers'] = multiprocessing.cpu_count()\n",
+"    data['train']['persistent_workers'] = True\n",
+"    data['train']['push_to_hub'] = True\n",
+"    data['train']['repo_id'] = MODEL_REPO_ID  # tuple(data['spk'])[0]\n",
+"    data['train']['private'] = True\n",
+"    config_path.write_text(json.dumps(data, indent=2, sort_keys=False))"
+]
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Run all Preprocessing Steps"
+],
+"metadata": {
+"id": "aF6OZkTZRzhj"
+}
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "OAPnD3xKD_Gw"
+},
+"outputs": [],
+"source": [
+"create_dataset(character=CHARACTER, do_extract_vocals=DO_EXTRACT_VOCALS)\n",
+"! svc pre-resample\n",
+"! svc pre-config\n",
+"! svc pre-hubert -fm crepe\n",
+"update_config()"
+]
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Training"
+],
+"metadata": {
+"id": "VpyGazF6R3CE"
+}
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"colab": {
+"background_save": true
+},
+"id": "MByHpf_wEByg"
+},
+"outputs": [],
+"source": [
+"from __future__ import annotations\n",
+"\n",
+"import os\n",
+"import re\n",
+"import warnings\n",
+"from logging import getLogger\n",
+"from multiprocessing import cpu_count\n",
+"from pathlib import Path\n",
+"from typing import Any\n",
+"\n",
+"import lightning.pytorch as pl\n",
+"import torch\n",
+"from lightning.pytorch.accelerators import MPSAccelerator, TPUAccelerator\n",
+"from lightning.pytorch.loggers import TensorBoardLogger\n",
+"from lightning.pytorch.strategies.ddp import DDPStrategy\n",
+"from lightning.pytorch.tuner import Tuner\n",
+"from torch.cuda.amp import autocast\n",
+"from torch.nn import functional as F\n",
+"from torch.utils.data import DataLoader\n",
+"from torch.utils.tensorboard.writer import SummaryWriter\n",
+"\n",
+"import so_vits_svc_fork.f0\n",
+"import so_vits_svc_fork.modules.commons as commons\n",
+"import so_vits_svc_fork.utils\n",
+"\n",
+"from so_vits_svc_fork import utils\n",
+"from so_vits_svc_fork.dataset import TextAudioCollate, TextAudioDataset\n",
+"from so_vits_svc_fork.logger import is_notebook\n",
+"from so_vits_svc_fork.modules.descriminators import MultiPeriodDiscriminator\n",
+"from so_vits_svc_fork.modules.losses import discriminator_loss, feature_loss, generator_loss, kl_loss\n",
+"from so_vits_svc_fork.modules.mel_processing import mel_spectrogram_torch\n",
+"from so_vits_svc_fork.modules.synthesizers import SynthesizerTrn\n",
+"\n",
+"from so_vits_svc_fork.train import VitsLightning, VCDataModule\n",
+"\n",
+"LOG = getLogger(__name__)\n",
+"torch.set_float32_matmul_precision(\"high\")\n",
+"\n",
+"\n",
+"from pathlib import Path\n",
+"\n",
+"from huggingface_hub import create_repo, upload_folder, login, list_repo_files, delete_file\n",
+"\n",
+"# if os.environ.get(\"HF_TOKEN\"):\n",
+"#     login(os.environ.get(\"HF_TOKEN\"))\n",
+"\n",
+"\n",
+"class HuggingFacePushCallback(pl.Callback):\n",
+"    def __init__(self, repo_id, private=False, every=100):\n",
+"        self.repo_id = repo_id\n",
+"        self.private = private\n",
+"        self.every = every\n",
+"\n",
+"    def on_validation_epoch_end(self, trainer, pl_module):\n",
+"        self.repo_url = create_repo(\n",
+"            repo_id=self.repo_id,\n",
+"            exist_ok=True,\n",
+"            private=self.private\n",
+"        )\n",
+"        self.repo_id = self.repo_url.repo_id\n",
+"        if pl_module.global_step == 0:\n",
+"            return\n",
+"        print(f\"\\n🤗 Pushing to Hugging Face Hub: {self.repo_url}...\")\n",
+"        model_dir = pl_module.hparams.model_dir\n",
+"        upload_folder(\n",
+"            repo_id=self.repo_id,\n",
+"            folder_path=model_dir,\n",
+"            path_in_repo=\".\",\n",
+"            commit_message=\"🍻 cheers\",\n",
+"            ignore_patterns=[\"*.git*\", \"*README.md*\", \"*__pycache__*\"],\n",
+"        )\n",
+"        ckpt_pattern = r'^(D_|G_)\\d+\\.pth$'\n",
+"        todelete = []\n",
+"        repo_ckpts = [x for x in list_repo_files(self.repo_id) if re.match(ckpt_pattern, x) and x not in [\"G_0.pth\", \"D_0.pth\"]]\n",
+"        local_ckpts = [x.name for x in Path(model_dir).glob(\"*.pth\") if re.match(ckpt_pattern, x.name)]\n",
+"        to_delete = set(repo_ckpts) - set(local_ckpts)\n",
+"\n",
+"        for fname in to_delete:\n",
+"            print(f\"🗑 Deleting {fname} from repo\")\n",
+"            delete_file(fname, self.repo_id)\n",
+"\n",
+"\n",
+"def train(\n",
+"    config_path: Path | str, model_path: Path | str, reset_optimizer: bool = False\n",
+"):\n",
+"    config_path = Path(config_path)\n",
+"    model_path = Path(model_path)\n",
+"\n",
+"    hparams = utils.get_backup_hparams(config_path, model_path)\n",
+"    utils.ensure_pretrained_model(model_path, hparams.model.get(\"type_\", \"hifi-gan\"))\n",
+"\n",
+"    datamodule = VCDataModule(hparams)\n",
+"    strategy = (\n",
+"        (\n",
+"            \"ddp_find_unused_parameters_true\"\n",
+"            if os.name != \"nt\"\n",
+"            else DDPStrategy(find_unused_parameters=True, process_group_backend=\"gloo\")\n",
+"        )\n",
+"        if torch.cuda.device_count() > 1\n",
+"        else \"auto\"\n",
+"    )\n",
+"    LOG.info(f\"Using strategy: {strategy}\")\n",
+"\n",
+"    callbacks = []\n",
+"    if hparams.train.push_to_hub:\n",
+"        callbacks.append(HuggingFacePushCallback(hparams.train.repo_id, hparams.train.private))\n",
+"    if not is_notebook():\n",
+"        callbacks.append(pl.callbacks.RichProgressBar())\n",
+"    if callbacks == []:\n",
+"        callbacks = None\n",
+"\n",
+"    trainer = pl.Trainer(\n",
+"        logger=TensorBoardLogger(\n",
+"            model_path, \"lightning_logs\", hparams.train.get(\"log_version\", 0)\n",
+"        ),\n",
+"        # profiler=\"simple\",\n",
+"        val_check_interval=hparams.train.eval_interval,\n",
+"        max_epochs=hparams.train.epochs,\n",
+"        check_val_every_n_epoch=None,\n",
+"        precision=\"16-mixed\"\n",
+"        if hparams.train.fp16_run\n",
+"        else \"bf16-mixed\"\n",
+"        if hparams.train.get(\"bf16_run\", False)\n",
+"        else 32,\n",
+"        strategy=strategy,\n",
+"        callbacks=callbacks,\n",
+"        benchmark=True,\n",
+"        enable_checkpointing=False,\n",
+"    )\n",
+"    tuner = Tuner(trainer)\n",
+"    model = VitsLightning(reset_optimizer=reset_optimizer, **hparams)\n",
+"\n",
+"    # automatic batch size scaling\n",
+"    batch_size = hparams.train.batch_size\n",
+"    batch_split = str(batch_size).split(\"-\")\n",
+"    batch_size = batch_split[0]\n",
+"    init_val = 2 if len(batch_split) <= 1 else int(batch_split[1])\n",
+"    max_trials = 25 if len(batch_split) <= 2 else int(batch_split[2])\n",
+"    if batch_size == \"auto\":\n",
+"        batch_size = \"binsearch\"\n",
+"    if batch_size in [\"power\", \"binsearch\"]:\n",
+"        model.tuning = True\n",
+"        tuner.scale_batch_size(\n",
+"            model,\n",
+"            mode=batch_size,\n",
+"            datamodule=datamodule,\n",
+"            steps_per_trial=1,\n",
+"            init_val=init_val,\n",
+"            max_trials=max_trials,\n",
+"        )\n",
+"        model.tuning = False\n",
+"    else:\n",
+"        batch_size = int(batch_size)\n",
+"    # automatic learning rate scaling is not supported for multiple optimizers\n",
+"    \"\"\"if hparams.train.learning_rate == \"auto\":\n",
+"    lr_finder = tuner.lr_find(model)\n",
+"    LOG.info(lr_finder.results)\n",
+"    fig = lr_finder.plot(suggest=True)\n",
+"    fig.savefig(model_path / \"lr_finder.png\")\"\"\"\n",
+"\n",
+"    trainer.fit(model, datamodule=datamodule)\n",
+"\n",
+"if __name__ == '__main__':\n",
+"    train('configs/44k/config.json', 'logs/44k')"
+]
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Train Cluster Model"
+],
+"metadata": {
+"id": "b2vNCDrSR8Xo"
+}
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "DBBEx-6Y1sOy"
+},
+"outputs": [],
+"source": [
+"! svc train-cluster"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {
+"id": "y_qYMuNY1tlm"
+},
+"outputs": [],
+"source": [
+"from huggingface_hub import upload_file\n",
+"\n",
+"upload_file(path_or_fileobj=\"/content/logs/44k/kmeans.pt\", repo_id=MODEL_REPO_ID, path_in_repo=\"kmeans.pt\")"
+]
+}
+],
+"metadata": {
+"accelerator": "GPU",
+"colab": {
+"machine_shape": "hm",
+"provenance": [],
+"authorship_tag": "ABX9TyOQeFSvxop9rlCaglNlNoXI",
+"include_colab_link": true
+},
+"gpuClass": "premium",
+"kernelspec": {
+"display_name": "Python 3",
+"name": "python3"
+},
+"language_info": {
+"name": "python"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 0
+}
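For reference, the artifacts this notebook pushes to MODEL_REPO_ID (the G_<step>.pth checkpoints, config.json, and kmeans.pt) are exactly what the updated app.py above pulls down at startup. A minimal sketch of loading them outside the Space, using only calls that already appear in this commit (the repo ID is the default used here and the checkpoint name is a placeholder; adjust both for your own repo):

# Sketch: load a trained so-vits-svc-fork model from the Hub the way app.py does.
# Assumes the repo contains config.json, at least one G_<step>.pth checkpoint,
# and optionally kmeans.pt (pushed by the last cell of the notebook).
import json
from pathlib import Path

import torch
from huggingface_hub import hf_hub_download, list_repo_files
from so_vits_svc_fork.hparams import HParams
from so_vits_svc_fork.inference.core import Svc

repo_id = "dog/kanye"  # MODEL_REPO_ID from the settings cell
files = list_repo_files(repo_id)

ckpt_name = "G_200.pth"  # placeholder: pick a generator checkpoint present in the repo
generator_path = hf_hub_download(repo_id, ckpt_name)
config_path = hf_hub_download(repo_id, "config.json")
cluster_model_path = hf_hub_download(repo_id, "kmeans.pt") if "kmeans.pt" in files else None

hparams = HParams(**json.loads(Path(config_path).read_text()))
device = "cuda" if torch.cuda.is_available() else "cpu"
model = Svc(
    net_g_path=generator_path,
    config_path=config_path,
    device=device,
    cluster_model_path=cluster_model_path,
)
print("Target speakers:", list(hparams.spk.keys()))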