Fix raw.arrow missing rows (#1145)
* fix raw.arrow missing rows: explicitly call writer.finalize() after the write loop so buffered rows are flushed to disk
---------
Co-authored-by: SWivid <swivid@qq.com>
- src/f5_tts/train/datasets/prepare_csv_wavs.py +2 -2
- src/f5_tts/train/datasets/prepare_emilia.py +1 -0
- src/f5_tts/train/datasets/prepare_emilia_v2.py +1 -0
- src/f5_tts/train/datasets/prepare_libritts.py +1 -0
- src/f5_tts/train/datasets/prepare_ljspeech.py +1 -0
- src/f5_tts/train/finetune_gradio.py +2 -1
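Every file gets the same one-line fix: keep the ArrowWriter write loop as it is, but call writer.finalize() before leaving the with block so that examples still buffered in memory are written out. Below is a minimal sketch of the pattern, assuming the Hugging Face datasets package; the record fields and output path are illustrative, not the project's actual schema.

```python
from datasets.arrow_writer import ArrowWriter

# Illustrative records only; the real prepare_* scripts build these from audio metadata.
result = [
    {"audio_path": f"wavs/{i:04d}.wav", "text": "example transcript", "duration": 1.0}
    for i in range(3)
]

with ArrowWriter(path="raw.arrow") as writer:
    for line in result:
        writer.write(line)
    # Exiting the context manager only closes the stream; finalize() is what
    # writes any still-buffered examples out to the file, so skipping it can
    # silently drop the trailing rows from raw.arrow.
    writer.finalize()
```

finalize() also returns the number of examples and bytes written, which makes for an easy sanity check against len(result).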
src/f5_tts/train/datasets/prepare_csv_wavs.py

@@ -208,11 +208,11 @@ def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_fine
     out_dir.mkdir(exist_ok=True, parents=True)
     print(f"\nSaving to {out_dir} ...")
 
-    # Save dataset with improved batch size for better I/O performance
     raw_arrow_path = out_dir / "raw.arrow"
-    with ArrowWriter(path=raw_arrow_path.as_posix()
+    with ArrowWriter(path=raw_arrow_path.as_posix()) as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # Save durations to JSON
     dur_json_path = out_dir / "duration.json"
src/f5_tts/train/datasets/prepare_emilia.py

@@ -181,6 +181,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # dup a json separately saving duration in case for DynamicBatchSampler ease
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
src/f5_tts/train/datasets/prepare_emilia_v2.py

@@ -68,6 +68,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
src/f5_tts/train/datasets/prepare_libritts.py

@@ -62,6 +62,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # dup a json separately saving duration in case for DynamicBatchSampler ease
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
src/f5_tts/train/datasets/prepare_ljspeech.py

@@ -39,6 +39,7 @@ def main():
     with ArrowWriter(path=f"{save_dir}/raw.arrow") as writer:
         for line in tqdm(result, desc="Writing to raw.arrow ..."):
             writer.write(line)
+        writer.finalize()
 
     # dup a json separately saving duration in case for DynamicBatchSampler ease
     with open(f"{save_dir}/duration.json", "w", encoding="utf-8") as f:
src/f5_tts/train/finetune_gradio.py

@@ -796,9 +796,10 @@ def create_metadata(name_project, ch_tokenizer, progress=gr.Progress()):
     min_second = round(min(duration_list), 2)
     max_second = round(max(duration_list), 2)
 
-    with ArrowWriter(path=file_raw
+    with ArrowWriter(path=file_raw) as writer:
         for line in progress.tqdm(result, total=len(result), desc="prepare data"):
             writer.write(line)
+        writer.finalize()
 
     with open(file_duration, "w") as f:
         json.dump({"duration": duration_list}, f, ensure_ascii=False)
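One way to confirm the fix after re-running a preparation script is to compare the number of rows in raw.arrow with the number of entries in duration.json, since both are produced from the same prepared examples. A small check, assuming the Hugging Face datasets package and a hypothetical output directory data/my_dataset:

```python
import json

from datasets import Dataset

save_dir = "data/my_dataset"  # hypothetical; use the folder the prepare_* script wrote to

ds = Dataset.from_file(f"{save_dir}/raw.arrow")
with open(f"{save_dir}/duration.json", encoding="utf-8") as f:
    duration_list = json.load(f)["duration"]

# With the finalize() fix in place, the two counts should match.
assert len(ds) == len(duration_list), f"{len(ds)} rows in raw.arrow vs {len(duration_list)} durations"
```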