Spaces:
				
			
			
	
			
			
		Configuration error
		
	
	
	
			
			
	
	
	
	
		
		
		Configuration error
		
	make a structure first
Browse files- src/f5_tts/api.py +5 -6
- src/f5_tts/{scripts β eval}/eval_infer_batch.py +0 -0
- src/f5_tts/{scripts β eval}/eval_infer_batch.sh +0 -0
- src/f5_tts/{scripts β eval}/eval_librispeech_test_clean.py +0 -0
- src/f5_tts/{scripts β eval}/eval_seedtts_testset.py +0 -0
- src/f5_tts/{data β eval/eval_testset}/librispeech_pc_test_clean_cross_sentence.lst +0 -0
- src/f5_tts/{data/inference-cli.toml β infer/examples/basic/basic.toml} +0 -0
- src/f5_tts/infer/examples/basic/basic_ref_en.wav +0 -0
- src/f5_tts/infer/examples/basic/basic_ref_zh.wav +0 -0
- {samples β src/f5_tts/infer/examples/multi}/country.flac +0 -0
- {samples β src/f5_tts/infer/examples/multi}/main.flac +0 -0
- {samples β src/f5_tts/infer/examples/multi}/story.toml +0 -0
- {samples β src/f5_tts/infer/examples/multi}/story.txt +0 -0
- {samples β src/f5_tts/infer/examples/multi}/town.flac +0 -0
- src/f5_tts/{data/Emilia_ZH_EN_pinyin β infer/examples}/vocab.txt +0 -0
- src/f5_tts/{inference_cli.py β infer/infer_cli.py} +1 -1
- src/f5_tts/{gradio_app.py β infer/infer_gradio.py} +0 -0
- src/f5_tts/{speech_edit.py β infer/speech_edit.py} +0 -0
- src/f5_tts/scripts/count_params_gflops.py +2 -2
- src/f5_tts/{finetune_cli.py β train/finetune_cli.py} +128 -128
- src/f5_tts/{finetune_gradio.py β train/finetune_gradio.py} +944 -944
- src/f5_tts/{train.py β train/train.py} +0 -0
    	
        src/f5_tts/api.py
    CHANGED
    
    | @@ -1,15 +1,14 @@ | |
|  | |
|  | |
|  | |
|  | |
| 1 | 
             
            import soundfile as sf
         | 
| 2 | 
             
            import torch
         | 
| 3 | 
            -
            import tqdm
         | 
| 4 | 
             
            from cached_path import cached_path
         | 
| 5 |  | 
| 6 | 
             
            from f5_tts.model import DiT, UNetT
         | 
| 7 | 
            -
            from f5_tts.model.utils import save_spectrogram
         | 
| 8 | 
            -
             | 
| 9 | 
             
            from f5_tts.model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
         | 
| 10 | 
            -
            from f5_tts.model.utils import seed_everything
         | 
| 11 | 
            -
            import random
         | 
| 12 | 
            -
            import sys
         | 
| 13 |  | 
| 14 |  | 
| 15 | 
             
            class F5TTS:
         | 
|  | |
| 1 | 
            +
            import random
         | 
| 2 | 
            +
            import sys
         | 
| 3 | 
            +
            import tqdm
         | 
| 4 | 
            +
             | 
| 5 | 
             
            import soundfile as sf
         | 
| 6 | 
             
            import torch
         | 
|  | |
| 7 | 
             
            from cached_path import cached_path
         | 
| 8 |  | 
| 9 | 
             
            from f5_tts.model import DiT, UNetT
         | 
| 10 | 
            +
            from f5_tts.model.utils import seed_everything, save_spectrogram
         | 
|  | |
| 11 | 
             
            from f5_tts.model.utils_infer import load_vocoder, load_model, infer_process, remove_silence_for_generated_wav
         | 
|  | |
|  | |
|  | |
| 12 |  | 
| 13 |  | 
| 14 | 
             
            class F5TTS:
         | 
    	
        src/f5_tts/{scripts β eval}/eval_infer_batch.py
    RENAMED
    
    | 
            File without changes
         | 
    	
        src/f5_tts/{scripts β eval}/eval_infer_batch.sh
    RENAMED
    
    | 
            File without changes
         | 
    	
        src/f5_tts/{scripts β eval}/eval_librispeech_test_clean.py
    RENAMED
    
    | 
            File without changes
         | 
    	
        src/f5_tts/{scripts β eval}/eval_seedtts_testset.py
    RENAMED
    
    | 
            File without changes
         | 
    	
        src/f5_tts/{data β eval/eval_testset}/librispeech_pc_test_clean_cross_sentence.lst
    RENAMED
    
    | 
            File without changes
         | 
    	
        src/f5_tts/{data/inference-cli.toml β infer/examples/basic/basic.toml}
    RENAMED
    
    | 
            File without changes
         | 
    	
        src/f5_tts/infer/examples/basic/basic_ref_en.wav
    ADDED
    
    | Binary file (256 kB). View file | 
|  | 
    	
        src/f5_tts/infer/examples/basic/basic_ref_zh.wav
    ADDED
    
    | Binary file (325 kB). View file | 
|  | 
    	
        {samples β src/f5_tts/infer/examples/multi}/country.flac
    RENAMED
    
    | 
            File without changes
         | 
    	
        {samples β src/f5_tts/infer/examples/multi}/main.flac
    RENAMED
    
    | 
            File without changes
         | 
    	
        {samples β src/f5_tts/infer/examples/multi}/story.toml
    RENAMED
    
    | 
            File without changes
         | 
    	
        {samples β src/f5_tts/infer/examples/multi}/story.txt
    RENAMED
    
    | 
            File without changes
         | 
    	
        {samples β src/f5_tts/infer/examples/multi}/town.flac
    RENAMED
    
    | 
            File without changes
         | 
    	
        src/f5_tts/{data/Emilia_ZH_EN_pinyin β infer/examples}/vocab.txt
    RENAMED
    
    | 
            File without changes
         | 
    	
        src/f5_tts/{inference_cli.py β infer/infer_cli.py}
    RENAMED
    
    | @@ -1,7 +1,7 @@ | |
| 1 | 
             
            import argparse
         | 
| 2 | 
             
            import codecs
         | 
| 3 | 
            -
            import re
         | 
| 4 | 
             
            import os
         | 
|  | |
| 5 | 
             
            from pathlib import Path
         | 
| 6 | 
             
            from importlib.resources import files
         | 
| 7 |  | 
|  | |
| 1 | 
             
            import argparse
         | 
| 2 | 
             
            import codecs
         | 
|  | |
| 3 | 
             
            import os
         | 
| 4 | 
            +
            import re
         | 
| 5 | 
             
            from pathlib import Path
         | 
| 6 | 
             
            from importlib.resources import files
         | 
| 7 |  | 
    	
        src/f5_tts/{gradio_app.py β infer/infer_gradio.py}
    RENAMED
    
    | 
            File without changes
         | 
    	
        src/f5_tts/{speech_edit.py β infer/speech_edit.py}
    RENAMED
    
    | 
            File without changes
         | 
    	
        src/f5_tts/scripts/count_params_gflops.py
    CHANGED
    
    | @@ -3,7 +3,7 @@ import os | |
| 3 |  | 
| 4 | 
             
            sys.path.append(os.getcwd())
         | 
| 5 |  | 
| 6 | 
            -
            from f5_tts.model import  | 
| 7 |  | 
| 8 | 
             
            import torch
         | 
| 9 | 
             
            import thop
         | 
| @@ -24,7 +24,7 @@ import thop | |
| 24 | 
             
            transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
         | 
| 25 |  | 
| 26 |  | 
| 27 | 
            -
            model =  | 
| 28 | 
             
            target_sample_rate = 24000
         | 
| 29 | 
             
            n_mel_channels = 100
         | 
| 30 | 
             
            hop_length = 256
         | 
|  | |
| 3 |  | 
| 4 | 
             
            sys.path.append(os.getcwd())
         | 
| 5 |  | 
| 6 | 
            +
            from f5_tts.model import CFM, DiT
         | 
| 7 |  | 
| 8 | 
             
            import torch
         | 
| 9 | 
             
            import thop
         | 
|  | |
| 24 | 
             
            transformer = DiT(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
         | 
| 25 |  | 
| 26 |  | 
| 27 | 
            +
            model = CFM(transformer=transformer)
         | 
| 28 | 
             
            target_sample_rate = 24000
         | 
| 29 | 
             
            n_mel_channels = 100
         | 
| 30 | 
             
            hop_length = 256
         | 
    	
        src/f5_tts/{finetune_cli.py β train/finetune_cli.py}
    RENAMED
    
    | @@ -1,128 +1,128 @@ | |
| 1 | 
            -
            import argparse
         | 
| 2 | 
            -
            import os
         | 
| 3 | 
            -
            import shutil
         | 
| 4 | 
            -
             | 
| 5 | 
            -
            from cached_path import cached_path
         | 
| 6 | 
            -
            from f5_tts.model import CFM, UNetT, DiT, Trainer
         | 
| 7 | 
            -
            from f5_tts.model.utils import get_tokenizer
         | 
| 8 | 
            -
            from f5_tts.model.dataset import load_dataset
         | 
| 9 | 
            -
             | 
| 10 | 
            -
            # -------------------------- Dataset Settings --------------------------- #
         | 
| 11 | 
            -
            target_sample_rate = 24000
         | 
| 12 | 
            -
            n_mel_channels = 100
         | 
| 13 | 
            -
            hop_length = 256
         | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
            # -------------------------- Argument Parsing --------------------------- #
         | 
| 17 | 
            -
            def parse_args():
         | 
| 18 | 
            -
                parser = argparse.ArgumentParser(description="Train CFM Model")
         | 
| 19 | 
            -
             | 
| 20 | 
            -
                parser.add_argument(
         | 
| 21 | 
            -
                    "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
         | 
| 22 | 
            -
                )
         | 
| 23 | 
            -
                parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
         | 
| 24 | 
            -
                parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
         | 
| 25 | 
            -
                parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
         | 
| 26 | 
            -
                parser.add_argument(
         | 
| 27 | 
            -
                    "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
         | 
| 28 | 
            -
                )
         | 
| 29 | 
            -
                parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
         | 
| 30 | 
            -
                parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
         | 
| 31 | 
            -
                parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
         | 
| 32 | 
            -
                parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
         | 
| 33 | 
            -
                parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
         | 
| 34 | 
            -
                parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
         | 
| 35 | 
            -
                parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
         | 
| 36 | 
            -
                parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
         | 
| 37 | 
            -
             | 
| 38 | 
            -
                parser.add_argument(
         | 
| 39 | 
            -
                    "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
         | 
| 40 | 
            -
                )
         | 
| 41 | 
            -
                parser.add_argument(
         | 
| 42 | 
            -
                    "--tokenizer_path",
         | 
| 43 | 
            -
                    type=str,
         | 
| 44 | 
            -
                    default=None,
         | 
| 45 | 
            -
                    help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
         | 
| 46 | 
            -
                )
         | 
| 47 | 
            -
             | 
| 48 | 
            -
                return parser.parse_args()
         | 
| 49 | 
            -
             | 
| 50 | 
            -
             | 
| 51 | 
            -
            # -------------------------- Training Settings -------------------------- #
         | 
| 52 | 
            -
             | 
| 53 | 
            -
             | 
| 54 | 
            -
            def main():
         | 
| 55 | 
            -
                args = parse_args()
         | 
| 56 | 
            -
             | 
| 57 | 
            -
                # Model parameters based on experiment name
         | 
| 58 | 
            -
                if args.exp_name == "F5TTS_Base":
         | 
| 59 | 
            -
                    wandb_resume_id = None
         | 
| 60 | 
            -
                    model_cls = DiT
         | 
| 61 | 
            -
                    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
         | 
| 62 | 
            -
                    if args.finetune:
         | 
| 63 | 
            -
                        ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
         | 
| 64 | 
            -
                elif args.exp_name == "E2TTS_Base":
         | 
| 65 | 
            -
                    wandb_resume_id = None
         | 
| 66 | 
            -
                    model_cls = UNetT
         | 
| 67 | 
            -
                    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
         | 
| 68 | 
            -
                    if args.finetune:
         | 
| 69 | 
            -
                        ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
         | 
| 70 | 
            -
             | 
| 71 | 
            -
                if args.finetune:
         | 
| 72 | 
            -
                    path_ckpt = os.path.join("ckpts", args.dataset_name)
         | 
| 73 | 
            -
                    if not os.path.isdir(path_ckpt):
         | 
| 74 | 
            -
                        os.makedirs(path_ckpt, exist_ok=True)
         | 
| 75 | 
            -
                        shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
         | 
| 76 | 
            -
             | 
| 77 | 
            -
                checkpoint_path = os.path.join("ckpts", args.dataset_name)
         | 
| 78 | 
            -
             | 
| 79 | 
            -
                # Use the tokenizer and tokenizer_path provided in the command line arguments
         | 
| 80 | 
            -
                tokenizer = args.tokenizer
         | 
| 81 | 
            -
                if tokenizer == "custom":
         | 
| 82 | 
            -
                    if not args.tokenizer_path:
         | 
| 83 | 
            -
                        raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
         | 
| 84 | 
            -
                    tokenizer_path = args.tokenizer_path
         | 
| 85 | 
            -
                else:
         | 
| 86 | 
            -
                    tokenizer_path = args.dataset_name
         | 
| 87 | 
            -
             | 
| 88 | 
            -
                vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
         | 
| 89 | 
            -
             | 
| 90 | 
            -
                mel_spec_kwargs = dict(
         | 
| 91 | 
            -
                    target_sample_rate=target_sample_rate,
         | 
| 92 | 
            -
                    n_mel_channels=n_mel_channels,
         | 
| 93 | 
            -
                    hop_length=hop_length,
         | 
| 94 | 
            -
                )
         | 
| 95 | 
            -
             | 
| 96 | 
            -
                e2tts = CFM(
         | 
| 97 | 
            -
                    transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
         | 
| 98 | 
            -
                    mel_spec_kwargs=mel_spec_kwargs,
         | 
| 99 | 
            -
                    vocab_char_map=vocab_char_map,
         | 
| 100 | 
            -
                )
         | 
| 101 | 
            -
             | 
| 102 | 
            -
                trainer = Trainer(
         | 
| 103 | 
            -
                    e2tts,
         | 
| 104 | 
            -
                    args.epochs,
         | 
| 105 | 
            -
                    args.learning_rate,
         | 
| 106 | 
            -
                    num_warmup_updates=args.num_warmup_updates,
         | 
| 107 | 
            -
                    save_per_updates=args.save_per_updates,
         | 
| 108 | 
            -
                    checkpoint_path=checkpoint_path,
         | 
| 109 | 
            -
                    batch_size=args.batch_size_per_gpu,
         | 
| 110 | 
            -
                    batch_size_type=args.batch_size_type,
         | 
| 111 | 
            -
                    max_samples=args.max_samples,
         | 
| 112 | 
            -
                    grad_accumulation_steps=args.grad_accumulation_steps,
         | 
| 113 | 
            -
                    max_grad_norm=args.max_grad_norm,
         | 
| 114 | 
            -
                    wandb_project="CFM-TTS",
         | 
| 115 | 
            -
                    wandb_run_name=args.exp_name,
         | 
| 116 | 
            -
                    wandb_resume_id=wandb_resume_id,
         | 
| 117 | 
            -
                    last_per_steps=args.last_per_steps,
         | 
| 118 | 
            -
                )
         | 
| 119 | 
            -
             | 
| 120 | 
            -
                train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
         | 
| 121 | 
            -
                trainer.train(
         | 
| 122 | 
            -
                    train_dataset,
         | 
| 123 | 
            -
                    resumable_with_seed=666,  # seed for shuffling dataset
         | 
| 124 | 
            -
                )
         | 
| 125 | 
            -
             | 
| 126 | 
            -
             | 
| 127 | 
            -
            if __name__ == "__main__":
         | 
| 128 | 
            -
                main()
         | 
|  | |
| 1 | 
            +
            import argparse
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import shutil
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from cached_path import cached_path
         | 
| 6 | 
            +
            from f5_tts.model import CFM, UNetT, DiT, Trainer
         | 
| 7 | 
            +
            from f5_tts.model.utils import get_tokenizer
         | 
| 8 | 
            +
            from f5_tts.model.dataset import load_dataset
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            # -------------------------- Dataset Settings --------------------------- #
         | 
| 11 | 
            +
            target_sample_rate = 24000
         | 
| 12 | 
            +
            n_mel_channels = 100
         | 
| 13 | 
            +
            hop_length = 256
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
            # -------------------------- Argument Parsing --------------------------- #
         | 
| 17 | 
            +
            def parse_args():
         | 
| 18 | 
            +
                parser = argparse.ArgumentParser(description="Train CFM Model")
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                parser.add_argument(
         | 
| 21 | 
            +
                    "--exp_name", type=str, default="F5TTS_Base", choices=["F5TTS_Base", "E2TTS_Base"], help="Experiment name"
         | 
| 22 | 
            +
                )
         | 
| 23 | 
            +
                parser.add_argument("--dataset_name", type=str, default="Emilia_ZH_EN", help="Name of the dataset to use")
         | 
| 24 | 
            +
                parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate for training")
         | 
| 25 | 
            +
                parser.add_argument("--batch_size_per_gpu", type=int, default=256, help="Batch size per GPU")
         | 
| 26 | 
            +
                parser.add_argument(
         | 
| 27 | 
            +
                    "--batch_size_type", type=str, default="frame", choices=["frame", "sample"], help="Batch size type"
         | 
| 28 | 
            +
                )
         | 
| 29 | 
            +
                parser.add_argument("--max_samples", type=int, default=16, help="Max sequences per batch")
         | 
| 30 | 
            +
                parser.add_argument("--grad_accumulation_steps", type=int, default=1, help="Gradient accumulation steps")
         | 
| 31 | 
            +
                parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Max gradient norm for clipping")
         | 
| 32 | 
            +
                parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs")
         | 
| 33 | 
            +
                parser.add_argument("--num_warmup_updates", type=int, default=5, help="Warmup steps")
         | 
| 34 | 
            +
                parser.add_argument("--save_per_updates", type=int, default=10, help="Save checkpoint every X steps")
         | 
| 35 | 
            +
                parser.add_argument("--last_per_steps", type=int, default=10, help="Save last checkpoint every X steps")
         | 
| 36 | 
            +
                parser.add_argument("--finetune", type=bool, default=True, help="Use Finetune")
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                parser.add_argument(
         | 
| 39 | 
            +
                    "--tokenizer", type=str, default="pinyin", choices=["pinyin", "char", "custom"], help="Tokenizer type"
         | 
| 40 | 
            +
                )
         | 
| 41 | 
            +
                parser.add_argument(
         | 
| 42 | 
            +
                    "--tokenizer_path",
         | 
| 43 | 
            +
                    type=str,
         | 
| 44 | 
            +
                    default=None,
         | 
| 45 | 
            +
                    help="Path to custom tokenizer vocab file (only used if tokenizer = 'custom')",
         | 
| 46 | 
            +
                )
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                return parser.parse_args()
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            # -------------------------- Training Settings -------------------------- #
         | 
| 52 | 
            +
             | 
| 53 | 
            +
             | 
| 54 | 
            +
            def main():
         | 
| 55 | 
            +
                args = parse_args()
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                # Model parameters based on experiment name
         | 
| 58 | 
            +
                if args.exp_name == "F5TTS_Base":
         | 
| 59 | 
            +
                    wandb_resume_id = None
         | 
| 60 | 
            +
                    model_cls = DiT
         | 
| 61 | 
            +
                    model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)
         | 
| 62 | 
            +
                    if args.finetune:
         | 
| 63 | 
            +
                        ckpt_path = str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.pt"))
         | 
| 64 | 
            +
                elif args.exp_name == "E2TTS_Base":
         | 
| 65 | 
            +
                    wandb_resume_id = None
         | 
| 66 | 
            +
                    model_cls = UNetT
         | 
| 67 | 
            +
                    model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
         | 
| 68 | 
            +
                    if args.finetune:
         | 
| 69 | 
            +
                        ckpt_path = str(cached_path("hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.pt"))
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                if args.finetune:
         | 
| 72 | 
            +
                    path_ckpt = os.path.join("ckpts", args.dataset_name)
         | 
| 73 | 
            +
                    if not os.path.isdir(path_ckpt):
         | 
| 74 | 
            +
                        os.makedirs(path_ckpt, exist_ok=True)
         | 
| 75 | 
            +
                        shutil.copy2(ckpt_path, os.path.join(path_ckpt, os.path.basename(ckpt_path)))
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                checkpoint_path = os.path.join("ckpts", args.dataset_name)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                # Use the tokenizer and tokenizer_path provided in the command line arguments
         | 
| 80 | 
            +
                tokenizer = args.tokenizer
         | 
| 81 | 
            +
                if tokenizer == "custom":
         | 
| 82 | 
            +
                    if not args.tokenizer_path:
         | 
| 83 | 
            +
                        raise ValueError("Custom tokenizer selected, but no tokenizer_path provided.")
         | 
| 84 | 
            +
                    tokenizer_path = args.tokenizer_path
         | 
| 85 | 
            +
                else:
         | 
| 86 | 
            +
                    tokenizer_path = args.dataset_name
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                vocab_char_map, vocab_size = get_tokenizer(tokenizer_path, tokenizer)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                mel_spec_kwargs = dict(
         | 
| 91 | 
            +
                    target_sample_rate=target_sample_rate,
         | 
| 92 | 
            +
                    n_mel_channels=n_mel_channels,
         | 
| 93 | 
            +
                    hop_length=hop_length,
         | 
| 94 | 
            +
                )
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                e2tts = CFM(
         | 
| 97 | 
            +
                    transformer=model_cls(**model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels),
         | 
| 98 | 
            +
                    mel_spec_kwargs=mel_spec_kwargs,
         | 
| 99 | 
            +
                    vocab_char_map=vocab_char_map,
         | 
| 100 | 
            +
                )
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                trainer = Trainer(
         | 
| 103 | 
            +
                    e2tts,
         | 
| 104 | 
            +
                    args.epochs,
         | 
| 105 | 
            +
                    args.learning_rate,
         | 
| 106 | 
            +
                    num_warmup_updates=args.num_warmup_updates,
         | 
| 107 | 
            +
                    save_per_updates=args.save_per_updates,
         | 
| 108 | 
            +
                    checkpoint_path=checkpoint_path,
         | 
| 109 | 
            +
                    batch_size=args.batch_size_per_gpu,
         | 
| 110 | 
            +
                    batch_size_type=args.batch_size_type,
         | 
| 111 | 
            +
                    max_samples=args.max_samples,
         | 
| 112 | 
            +
                    grad_accumulation_steps=args.grad_accumulation_steps,
         | 
| 113 | 
            +
                    max_grad_norm=args.max_grad_norm,
         | 
| 114 | 
            +
                    wandb_project="CFM-TTS",
         | 
| 115 | 
            +
                    wandb_run_name=args.exp_name,
         | 
| 116 | 
            +
                    wandb_resume_id=wandb_resume_id,
         | 
| 117 | 
            +
                    last_per_steps=args.last_per_steps,
         | 
| 118 | 
            +
                )
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                train_dataset = load_dataset(args.dataset_name, tokenizer, mel_spec_kwargs=mel_spec_kwargs)
         | 
| 121 | 
            +
                trainer.train(
         | 
| 122 | 
            +
                    train_dataset,
         | 
| 123 | 
            +
                    resumable_with_seed=666,  # seed for shuffling dataset
         | 
| 124 | 
            +
                )
         | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
            if __name__ == "__main__":
         | 
| 128 | 
            +
                main()
         | 
    	
        src/f5_tts/{finetune_gradio.py β train/finetune_gradio.py}
    RENAMED
    
    | @@ -1,944 +1,944 @@ | |
| 1 | 
            -
            import  | 
| 2 | 
            -
            import  | 
| 3 | 
            -
             | 
| 4 | 
            -
            import  | 
| 5 | 
            -
            import  | 
| 6 | 
            -
             | 
| 7 | 
            -
            import  | 
| 8 | 
            -
            import  | 
| 9 | 
            -
            import  | 
| 10 | 
            -
            import  | 
| 11 | 
            -
            import  | 
| 12 | 
            -
             | 
| 13 | 
            -
            import  | 
| 14 | 
            -
             | 
| 15 | 
            -
             | 
| 16 | 
            -
            import  | 
| 17 | 
            -
            import  | 
| 18 | 
            -
             | 
| 19 | 
            -
            import  | 
| 20 | 
            -
             | 
| 21 | 
            -
            import  | 
| 22 | 
            -
            import  | 
| 23 | 
            -
            import  | 
| 24 | 
            -
            import  | 
| 25 | 
            -
             | 
| 26 | 
            -
            from  | 
| 27 | 
            -
            from f5_tts. | 
| 28 | 
            -
             | 
| 29 | 
            -
             | 
| 30 | 
            -
            training_process = None
         | 
| 31 | 
            -
            system = platform.system()
         | 
| 32 | 
            -
            python_executable = sys.executable or "python"
         | 
| 33 | 
            -
            tts_api = None
         | 
| 34 | 
            -
            last_checkpoint = ""
         | 
| 35 | 
            -
            last_device = ""
         | 
| 36 | 
            -
             | 
| 37 | 
            -
            path_data = "data"
         | 
| 38 | 
            -
             | 
| 39 | 
            -
            device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
         | 
| 40 | 
            -
             | 
| 41 | 
            -
            pipe = None
         | 
| 42 | 
            -
             | 
| 43 | 
            -
             | 
| 44 | 
            -
            # Load metadata
         | 
| 45 | 
            -
            def get_audio_duration(audio_path):
         | 
| 46 | 
            -
                """Calculate the duration of an audio file."""
         | 
| 47 | 
            -
                audio, sample_rate = torchaudio.load(audio_path)
         | 
| 48 | 
            -
                num_channels = audio.shape[0]
         | 
| 49 | 
            -
                return audio.shape[1] / (sample_rate * num_channels)
         | 
| 50 | 
            -
             | 
| 51 | 
            -
             | 
| 52 | 
            -
            def clear_text(text):
         | 
| 53 | 
            -
                """Clean and prepare text by lowering the case and stripping whitespace."""
         | 
| 54 | 
            -
                return text.lower().strip()
         | 
| 55 | 
            -
             | 
| 56 | 
            -
             | 
| 57 | 
            -
            def get_rms(
         | 
| 58 | 
            -
                y,
         | 
| 59 | 
            -
                frame_length=2048,
         | 
| 60 | 
            -
                hop_length=512,
         | 
| 61 | 
            -
                pad_mode="constant",
         | 
| 62 | 
            -
            ):  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
         | 
| 63 | 
            -
                padding = (int(frame_length // 2), int(frame_length // 2))
         | 
| 64 | 
            -
                y = np.pad(y, padding, mode=pad_mode)
         | 
| 65 | 
            -
             | 
| 66 | 
            -
                axis = -1
         | 
| 67 | 
            -
                # put our new within-frame axis at the end for now
         | 
| 68 | 
            -
                out_strides = y.strides + tuple([y.strides[axis]])
         | 
| 69 | 
            -
                # Reduce the shape on the framing axis
         | 
| 70 | 
            -
                x_shape_trimmed = list(y.shape)
         | 
| 71 | 
            -
                x_shape_trimmed[axis] -= frame_length - 1
         | 
| 72 | 
            -
                out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
         | 
| 73 | 
            -
                xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
         | 
| 74 | 
            -
                if axis < 0:
         | 
| 75 | 
            -
                    target_axis = axis - 1
         | 
| 76 | 
            -
                else:
         | 
| 77 | 
            -
                    target_axis = axis + 1
         | 
| 78 | 
            -
                xw = np.moveaxis(xw, -1, target_axis)
         | 
| 79 | 
            -
                # Downsample along the target axis
         | 
| 80 | 
            -
                slices = [slice(None)] * xw.ndim
         | 
| 81 | 
            -
                slices[axis] = slice(0, None, hop_length)
         | 
| 82 | 
            -
                x = xw[tuple(slices)]
         | 
| 83 | 
            -
             | 
| 84 | 
            -
                # Calculate power
         | 
| 85 | 
            -
                power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
         | 
| 86 | 
            -
             | 
| 87 | 
            -
                return np.sqrt(power)
         | 
| 88 | 
            -
             | 
| 89 | 
            -
             | 
| 90 | 
            -
            class Slicer:  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
         | 
| 91 | 
            -
                def __init__(
         | 
| 92 | 
            -
                    self,
         | 
| 93 | 
            -
                    sr: int,
         | 
| 94 | 
            -
                    threshold: float = -40.0,
         | 
| 95 | 
            -
                    min_length: int = 2000,
         | 
| 96 | 
            -
                    min_interval: int = 300,
         | 
| 97 | 
            -
                    hop_size: int = 20,
         | 
| 98 | 
            -
                    max_sil_kept: int = 2000,
         | 
| 99 | 
            -
                ):
         | 
| 100 | 
            -
                    if not min_length >= min_interval >= hop_size:
         | 
| 101 | 
            -
                        raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
         | 
| 102 | 
            -
                    if not max_sil_kept >= hop_size:
         | 
| 103 | 
            -
                        raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
         | 
| 104 | 
            -
                    min_interval = sr * min_interval / 1000
         | 
| 105 | 
            -
                    self.threshold = 10 ** (threshold / 20.0)
         | 
| 106 | 
            -
                    self.hop_size = round(sr * hop_size / 1000)
         | 
| 107 | 
            -
                    self.win_size = min(round(min_interval), 4 * self.hop_size)
         | 
| 108 | 
            -
                    self.min_length = round(sr * min_length / 1000 / self.hop_size)
         | 
| 109 | 
            -
                    self.min_interval = round(min_interval / self.hop_size)
         | 
| 110 | 
            -
                    self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
         | 
| 111 | 
            -
             | 
| 112 | 
            -
                def _apply_slice(self, waveform, begin, end):
         | 
| 113 | 
            -
                    if len(waveform.shape) > 1:
         | 
| 114 | 
            -
                        return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
         | 
| 115 | 
            -
                    else:
         | 
| 116 | 
            -
                        return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]
         | 
| 117 | 
            -
             | 
| 118 | 
            -
                # @timeit
         | 
| 119 | 
            -
                def slice(self, waveform):
         | 
| 120 | 
            -
                    if len(waveform.shape) > 1:
         | 
| 121 | 
            -
                        samples = waveform.mean(axis=0)
         | 
| 122 | 
            -
                    else:
         | 
| 123 | 
            -
                        samples = waveform
         | 
| 124 | 
            -
                    if samples.shape[0] <= self.min_length:
         | 
| 125 | 
            -
                        return [waveform]
         | 
| 126 | 
            -
                    rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
         | 
| 127 | 
            -
                    sil_tags = []
         | 
| 128 | 
            -
                    silence_start = None
         | 
| 129 | 
            -
                    clip_start = 0
         | 
| 130 | 
            -
                    for i, rms in enumerate(rms_list):
         | 
| 131 | 
            -
                        # Keep looping while frame is silent.
         | 
| 132 | 
            -
                        if rms < self.threshold:
         | 
| 133 | 
            -
                            # Record start of silent frames.
         | 
| 134 | 
            -
                            if silence_start is None:
         | 
| 135 | 
            -
                                silence_start = i
         | 
| 136 | 
            -
                            continue
         | 
| 137 | 
            -
                        # Keep looping while frame is not silent and silence start has not been recorded.
         | 
| 138 | 
            -
                        if silence_start is None:
         | 
| 139 | 
            -
                            continue
         | 
| 140 | 
            -
                        # Clear recorded silence start if interval is not enough or clip is too short
         | 
| 141 | 
            -
                        is_leading_silence = silence_start == 0 and i > self.max_sil_kept
         | 
| 142 | 
            -
                        need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
         | 
| 143 | 
            -
                        if not is_leading_silence and not need_slice_middle:
         | 
| 144 | 
            -
                            silence_start = None
         | 
| 145 | 
            -
                            continue
         | 
| 146 | 
            -
                        # Need slicing. Record the range of silent frames to be removed.
         | 
| 147 | 
            -
                        if i - silence_start <= self.max_sil_kept:
         | 
| 148 | 
            -
                            pos = rms_list[silence_start : i + 1].argmin() + silence_start
         | 
| 149 | 
            -
                            if silence_start == 0:
         | 
| 150 | 
            -
                                sil_tags.append((0, pos))
         | 
| 151 | 
            -
                            else:
         | 
| 152 | 
            -
                                sil_tags.append((pos, pos))
         | 
| 153 | 
            -
                            clip_start = pos
         | 
| 154 | 
            -
                        elif i - silence_start <= self.max_sil_kept * 2:
         | 
| 155 | 
            -
                            pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
         | 
| 156 | 
            -
                            pos += i - self.max_sil_kept
         | 
| 157 | 
            -
                            pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
         | 
| 158 | 
            -
                            pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
         | 
| 159 | 
            -
                            if silence_start == 0:
         | 
| 160 | 
            -
                                sil_tags.append((0, pos_r))
         | 
| 161 | 
            -
                                clip_start = pos_r
         | 
| 162 | 
            -
                            else:
         | 
| 163 | 
            -
                                sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
         | 
| 164 | 
            -
                                clip_start = max(pos_r, pos)
         | 
| 165 | 
            -
                        else:
         | 
| 166 | 
            -
                            pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
         | 
| 167 | 
            -
                            pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
         | 
| 168 | 
            -
                            if silence_start == 0:
         | 
| 169 | 
            -
                                sil_tags.append((0, pos_r))
         | 
| 170 | 
            -
                            else:
         | 
| 171 | 
            -
                                sil_tags.append((pos_l, pos_r))
         | 
| 172 | 
            -
                            clip_start = pos_r
         | 
| 173 | 
            -
                        silence_start = None
         | 
| 174 | 
            -
                    # Deal with trailing silence.
         | 
| 175 | 
            -
                    total_frames = rms_list.shape[0]
         | 
| 176 | 
            -
                    if silence_start is not None and total_frames - silence_start >= self.min_interval:
         | 
| 177 | 
            -
                        silence_end = min(total_frames, silence_start + self.max_sil_kept)
         | 
| 178 | 
            -
                        pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
         | 
| 179 | 
            -
                        sil_tags.append((pos, total_frames + 1))
         | 
| 180 | 
            -
                    # Apply and return slices.
         | 
| 181 | 
            -
                    ####ι³ι’+θ΅·ε§ζΆι΄+η»ζ’ζΆι΄
         | 
| 182 | 
            -
                    if len(sil_tags) == 0:
         | 
| 183 | 
            -
                        return [[waveform, 0, int(total_frames * self.hop_size)]]
         | 
| 184 | 
            -
                    else:
         | 
| 185 | 
            -
                        chunks = []
         | 
| 186 | 
            -
                        if sil_tags[0][0] > 0:
         | 
| 187 | 
            -
                            chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
         | 
| 188 | 
            -
                        for i in range(len(sil_tags) - 1):
         | 
| 189 | 
            -
                            chunks.append(
         | 
| 190 | 
            -
                                [
         | 
| 191 | 
            -
                                    self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
         | 
| 192 | 
            -
                                    int(sil_tags[i][1] * self.hop_size),
         | 
| 193 | 
            -
                                    int(sil_tags[i + 1][0] * self.hop_size),
         | 
| 194 | 
            -
                                ]
         | 
| 195 | 
            -
                            )
         | 
| 196 | 
            -
                        if sil_tags[-1][1] < total_frames:
         | 
| 197 | 
            -
                            chunks.append(
         | 
| 198 | 
            -
                                [
         | 
| 199 | 
            -
                                    self._apply_slice(waveform, sil_tags[-1][1], total_frames),
         | 
| 200 | 
            -
                                    int(sil_tags[-1][1] * self.hop_size),
         | 
| 201 | 
            -
                                    int(total_frames * self.hop_size),
         | 
| 202 | 
            -
                                ]
         | 
| 203 | 
            -
                            )
         | 
| 204 | 
            -
                        return chunks
         | 
| 205 | 
            -
             | 
| 206 | 
            -
             | 
| 207 | 
            -
            # terminal
         | 
| 208 | 
            -
            def terminate_process_tree(pid, including_parent=True):
         | 
| 209 | 
            -
                try:
         | 
| 210 | 
            -
                    parent = psutil.Process(pid)
         | 
| 211 | 
            -
                except psutil.NoSuchProcess:
         | 
| 212 | 
            -
                    # Process already terminated
         | 
| 213 | 
            -
                    return
         | 
| 214 | 
            -
             | 
| 215 | 
            -
                children = parent.children(recursive=True)
         | 
| 216 | 
            -
                for child in children:
         | 
| 217 | 
            -
                    try:
         | 
| 218 | 
            -
                        os.kill(child.pid, signal.SIGTERM)  # or signal.SIGKILL
         | 
| 219 | 
            -
                    except OSError:
         | 
| 220 | 
            -
                        pass
         | 
| 221 | 
            -
                if including_parent:
         | 
| 222 | 
            -
                    try:
         | 
| 223 | 
            -
                        os.kill(parent.pid, signal.SIGTERM)  # or signal.SIGKILL
         | 
| 224 | 
            -
                    except OSError:
         | 
| 225 | 
            -
                        pass
         | 
| 226 | 
            -
             | 
| 227 | 
            -
             | 
| 228 | 
            -
            def terminate_process(pid):
         | 
| 229 | 
            -
                if system == "Windows":
         | 
| 230 | 
            -
                    cmd = f"taskkill /t /f /pid {pid}"
         | 
| 231 | 
            -
                    os.system(cmd)
         | 
| 232 | 
            -
                else:
         | 
| 233 | 
            -
                    terminate_process_tree(pid)
         | 
| 234 | 
            -
             | 
| 235 | 
            -
             | 
| 236 | 
            -
            def start_training(
         | 
| 237 | 
            -
                dataset_name="",
         | 
| 238 | 
            -
                exp_name="F5TTS_Base",
         | 
| 239 | 
            -
                learning_rate=1e-4,
         | 
| 240 | 
            -
                batch_size_per_gpu=400,
         | 
| 241 | 
            -
                batch_size_type="frame",
         | 
| 242 | 
            -
                max_samples=64,
         | 
| 243 | 
            -
                grad_accumulation_steps=1,
         | 
| 244 | 
            -
                max_grad_norm=1.0,
         | 
| 245 | 
            -
                epochs=11,
         | 
| 246 | 
            -
                num_warmup_updates=200,
         | 
| 247 | 
            -
                save_per_updates=400,
         | 
| 248 | 
            -
                last_per_steps=800,
         | 
| 249 | 
            -
                finetune=True,
         | 
| 250 | 
            -
            ):
         | 
| 251 | 
            -
                global training_process, tts_api
         | 
| 252 | 
            -
             | 
| 253 | 
            -
                if tts_api is not None:
         | 
| 254 | 
            -
                    del tts_api
         | 
| 255 | 
            -
                    gc.collect()
         | 
| 256 | 
            -
                    torch.cuda.empty_cache()
         | 
| 257 | 
            -
                    tts_api = None
         | 
| 258 | 
            -
             | 
| 259 | 
            -
                path_project = os.path.join(path_data, dataset_name + "_pinyin")
         | 
| 260 | 
            -
             | 
| 261 | 
            -
                if not os.path.isdir(path_project):
         | 
| 262 | 
            -
                    yield (
         | 
| 263 | 
            -
                        f"There is not project with name {dataset_name}",
         | 
| 264 | 
            -
                        gr.update(interactive=True),
         | 
| 265 | 
            -
                        gr.update(interactive=False),
         | 
| 266 | 
            -
                    )
         | 
| 267 | 
            -
                    return
         | 
| 268 | 
            -
             | 
| 269 | 
            -
                file_raw = os.path.join(path_project, "raw.arrow")
         | 
| 270 | 
            -
                if not os.path.isfile(file_raw):
         | 
| 271 | 
            -
                    yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
         | 
| 272 | 
            -
                    return
         | 
| 273 | 
            -
             | 
| 274 | 
            -
                # Check if a training process is already running
         | 
| 275 | 
            -
                if training_process is not None:
         | 
| 276 | 
            -
                    return "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
         | 
| 277 | 
            -
             | 
| 278 | 
            -
                yield "start train", gr.update(interactive=False), gr.update(interactive=False)
         | 
| 279 | 
            -
             | 
| 280 | 
            -
                # Command to run the training script with the specified arguments
         | 
| 281 | 
            -
                cmd = (
         | 
| 282 | 
            -
                    f"accelerate launch finetune-cli.py --exp_name {exp_name} "
         | 
| 283 | 
            -
                    f"--learning_rate {learning_rate} "
         | 
| 284 | 
            -
                    f"--batch_size_per_gpu {batch_size_per_gpu} "
         | 
| 285 | 
            -
                    f"--batch_size_type {batch_size_type} "
         | 
| 286 | 
            -
                    f"--max_samples {max_samples} "
         | 
| 287 | 
            -
                    f"--grad_accumulation_steps {grad_accumulation_steps} "
         | 
| 288 | 
            -
                    f"--max_grad_norm {max_grad_norm} "
         | 
| 289 | 
            -
                    f"--epochs {epochs} "
         | 
| 290 | 
            -
                    f"--num_warmup_updates {num_warmup_updates} "
         | 
| 291 | 
            -
                    f"--save_per_updates {save_per_updates} "
         | 
| 292 | 
            -
                    f"--last_per_steps {last_per_steps} "
         | 
| 293 | 
            -
                    f"--dataset_name {dataset_name}"
         | 
| 294 | 
            -
                )
         | 
| 295 | 
            -
                if finetune:
         | 
| 296 | 
            -
                    cmd += f" --finetune {finetune}"
         | 
| 297 | 
            -
             | 
| 298 | 
            -
                print(cmd)
         | 
| 299 | 
            -
             | 
| 300 | 
            -
                try:
         | 
| 301 | 
            -
                    # Start the training process
         | 
| 302 | 
            -
                    training_process = subprocess.Popen(cmd, shell=True)
         | 
| 303 | 
            -
             | 
| 304 | 
            -
                    time.sleep(5)
         | 
| 305 | 
            -
                    yield "train start", gr.update(interactive=False), gr.update(interactive=True)
         | 
| 306 | 
            -
             | 
| 307 | 
            -
                    # Wait for the training process to finish
         | 
| 308 | 
            -
                    training_process.wait()
         | 
| 309 | 
            -
                    time.sleep(1)
         | 
| 310 | 
            -
             | 
| 311 | 
            -
                    if training_process is None:
         | 
| 312 | 
            -
                        text_info = "train stop"
         | 
| 313 | 
            -
                    else:
         | 
| 314 | 
            -
                        text_info = "train complete !"
         | 
| 315 | 
            -
             | 
| 316 | 
            -
                except Exception as e:  # Catch all exceptions
         | 
| 317 | 
            -
                    # Ensure that we reset the training process variable in case of an error
         | 
| 318 | 
            -
                    text_info = f"An error occurred: {str(e)}"
         | 
| 319 | 
            -
             | 
| 320 | 
            -
                training_process = None
         | 
| 321 | 
            -
             | 
| 322 | 
            -
                yield text_info, gr.update(interactive=True), gr.update(interactive=False)
         | 
| 323 | 
            -
             | 
| 324 | 
            -
             | 
| 325 | 
            -
            def stop_training():
         | 
| 326 | 
            -
                global training_process
         | 
| 327 | 
            -
                if training_process is None:
         | 
| 328 | 
            -
                    return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
         | 
| 329 | 
            -
                terminate_process_tree(training_process.pid)
         | 
| 330 | 
            -
                training_process = None
         | 
| 331 | 
            -
                return "train stop", gr.update(interactive=True), gr.update(interactive=False)
         | 
| 332 | 
            -
             | 
| 333 | 
            -
             | 
| 334 | 
            -
            def create_data_project(name):
         | 
| 335 | 
            -
                name += "_pinyin"
         | 
| 336 | 
            -
                os.makedirs(os.path.join(path_data, name), exist_ok=True)
         | 
| 337 | 
            -
                os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
         | 
| 338 | 
            -
             | 
| 339 | 
            -
             | 
| 340 | 
            -
            def transcribe(file_audio, language="english"):
         | 
| 341 | 
            -
                global pipe
         | 
| 342 | 
            -
             | 
| 343 | 
            -
                if pipe is None:
         | 
| 344 | 
            -
                    pipe = pipeline(
         | 
| 345 | 
            -
                        "automatic-speech-recognition",
         | 
| 346 | 
            -
                        model="openai/whisper-large-v3-turbo",
         | 
| 347 | 
            -
                        torch_dtype=torch.float16,
         | 
| 348 | 
            -
                        device=device,
         | 
| 349 | 
            -
                    )
         | 
| 350 | 
            -
             | 
| 351 | 
            -
                text_transcribe = pipe(
         | 
| 352 | 
            -
                    file_audio,
         | 
| 353 | 
            -
                    chunk_length_s=30,
         | 
| 354 | 
            -
                    batch_size=128,
         | 
| 355 | 
            -
                    generate_kwargs={"task": "transcribe", "language": language},
         | 
| 356 | 
            -
                    return_timestamps=False,
         | 
| 357 | 
            -
                )["text"].strip()
         | 
| 358 | 
            -
                return text_transcribe
         | 
| 359 | 
            -
             | 
| 360 | 
            -
             | 
| 361 | 
            -
            def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
         | 
| 362 | 
            -
                name_project += "_pinyin"
         | 
| 363 | 
            -
                path_project = os.path.join(path_data, name_project)
         | 
| 364 | 
            -
                path_dataset = os.path.join(path_project, "dataset")
         | 
| 365 | 
            -
                path_project_wavs = os.path.join(path_project, "wavs")
         | 
| 366 | 
            -
                file_metadata = os.path.join(path_project, "metadata.csv")
         | 
| 367 | 
            -
             | 
| 368 | 
            -
                if audio_files is None:
         | 
| 369 | 
            -
                    return "You need to load an audio file."
         | 
| 370 | 
            -
             | 
| 371 | 
            -
                if os.path.isdir(path_project_wavs):
         | 
| 372 | 
            -
                    shutil.rmtree(path_project_wavs)
         | 
| 373 | 
            -
             | 
| 374 | 
            -
                if os.path.isfile(file_metadata):
         | 
| 375 | 
            -
                    os.remove(file_metadata)
         | 
| 376 | 
            -
             | 
| 377 | 
            -
                os.makedirs(path_project_wavs, exist_ok=True)
         | 
| 378 | 
            -
             | 
| 379 | 
            -
                if user:
         | 
| 380 | 
            -
                    file_audios = [
         | 
| 381 | 
            -
                        file
         | 
| 382 | 
            -
                        for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac")
         | 
| 383 | 
            -
                        for file in glob(os.path.join(path_dataset, format))
         | 
| 384 | 
            -
                    ]
         | 
| 385 | 
            -
                    if file_audios == []:
         | 
| 386 | 
            -
                        return "No audio file was found in the dataset."
         | 
| 387 | 
            -
                else:
         | 
| 388 | 
            -
                    file_audios = audio_files
         | 
| 389 | 
            -
             | 
| 390 | 
            -
                alpha = 0.5
         | 
| 391 | 
            -
                _max = 1.0
         | 
| 392 | 
            -
                slicer = Slicer(24000)
         | 
| 393 | 
            -
             | 
| 394 | 
            -
                num = 0
         | 
| 395 | 
            -
                error_num = 0
         | 
| 396 | 
            -
                data = ""
         | 
| 397 | 
            -
                for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))):
         | 
| 398 | 
            -
                    audio, _ = librosa.load(file_audio, sr=24000, mono=True)
         | 
| 399 | 
            -
             | 
| 400 | 
            -
                    list_slicer = slicer.slice(audio)
         | 
| 401 | 
            -
                    for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
         | 
| 402 | 
            -
                        name_segment = os.path.join(f"segment_{num}")
         | 
| 403 | 
            -
                        file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
         | 
| 404 | 
            -
             | 
| 405 | 
            -
                        tmp_max = np.abs(chunk).max()
         | 
| 406 | 
            -
                        if tmp_max > 1:
         | 
| 407 | 
            -
                            chunk /= tmp_max
         | 
| 408 | 
            -
                        chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
         | 
| 409 | 
            -
                        wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
         | 
| 410 | 
            -
             | 
| 411 | 
            -
                        try:
         | 
| 412 | 
            -
                            text = transcribe(file_segment, language)
         | 
| 413 | 
            -
                            text = text.lower().strip().replace('"', "")
         | 
| 414 | 
            -
             | 
| 415 | 
            -
                            data += f"{name_segment}|{text}\n"
         | 
| 416 | 
            -
             | 
| 417 | 
            -
                            num += 1
         | 
| 418 | 
            -
                        except:  # noqa: E722
         | 
| 419 | 
            -
                            error_num += 1
         | 
| 420 | 
            -
             | 
| 421 | 
            -
                with open(file_metadata, "w", encoding="utf-8") as f:
         | 
| 422 | 
            -
                    f.write(data)
         | 
| 423 | 
            -
             | 
| 424 | 
            -
                if error_num != []:
         | 
| 425 | 
            -
                    error_text = f"\nerror files : {error_num}"
         | 
| 426 | 
            -
                else:
         | 
| 427 | 
            -
                    error_text = ""
         | 
| 428 | 
            -
             | 
| 429 | 
            -
                return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
         | 
| 430 | 
            -
             | 
| 431 | 
            -
             | 
| 432 | 
            -
            def format_seconds_to_hms(seconds):
         | 
| 433 | 
            -
                hours = int(seconds / 3600)
         | 
| 434 | 
            -
                minutes = int((seconds % 3600) / 60)
         | 
| 435 | 
            -
                seconds = seconds % 60
         | 
| 436 | 
            -
                return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
         | 
| 437 | 
            -
             | 
| 438 | 
            -
             | 
| 439 | 
            -
            def create_metadata(name_project, progress=gr.Progress()):
         | 
| 440 | 
            -
                name_project += "_pinyin"
         | 
| 441 | 
            -
                path_project = os.path.join(path_data, name_project)
         | 
| 442 | 
            -
                path_project_wavs = os.path.join(path_project, "wavs")
         | 
| 443 | 
            -
                file_metadata = os.path.join(path_project, "metadata.csv")
         | 
| 444 | 
            -
                file_raw = os.path.join(path_project, "raw.arrow")
         | 
| 445 | 
            -
                file_duration = os.path.join(path_project, "duration.json")
         | 
| 446 | 
            -
                file_vocab = os.path.join(path_project, "vocab.txt")
         | 
| 447 | 
            -
             | 
| 448 | 
            -
                if not os.path.isfile(file_metadata):
         | 
| 449 | 
            -
                    return "The file was not found in " + file_metadata
         | 
| 450 | 
            -
             | 
| 451 | 
            -
                with open(file_metadata, "r", encoding="utf-8") as f:
         | 
| 452 | 
            -
                    data = f.read()
         | 
| 453 | 
            -
             | 
| 454 | 
            -
                audio_path_list = []
         | 
| 455 | 
            -
                text_list = []
         | 
| 456 | 
            -
                duration_list = []
         | 
| 457 | 
            -
             | 
| 458 | 
            -
                count = data.split("\n")
         | 
| 459 | 
            -
                lenght = 0
         | 
| 460 | 
            -
                result = []
         | 
| 461 | 
            -
                error_files = []
         | 
| 462 | 
            -
                for line in progress.tqdm(data.split("\n"), total=count):
         | 
| 463 | 
            -
                    sp_line = line.split("|")
         | 
| 464 | 
            -
                    if len(sp_line) != 2:
         | 
| 465 | 
            -
                        continue
         | 
| 466 | 
            -
                    name_audio, text = sp_line[:2]
         | 
| 467 | 
            -
             | 
| 468 | 
            -
                    file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
         | 
| 469 | 
            -
             | 
| 470 | 
            -
                    if not os.path.isfile(file_audio):
         | 
| 471 | 
            -
                        error_files.append(file_audio)
         | 
| 472 | 
            -
                        continue
         | 
| 473 | 
            -
             | 
| 474 | 
            -
                    duraction = get_audio_duration(file_audio)
         | 
| 475 | 
            -
                    if duraction < 2 and duraction > 15:
         | 
| 476 | 
            -
                        continue
         | 
| 477 | 
            -
                    if len(text) < 4:
         | 
| 478 | 
            -
                        continue
         | 
| 479 | 
            -
             | 
| 480 | 
            -
                    text = clear_text(text)
         | 
| 481 | 
            -
                    text = convert_char_to_pinyin([text], polyphone=True)[0]
         | 
| 482 | 
            -
             | 
| 483 | 
            -
                    audio_path_list.append(file_audio)
         | 
| 484 | 
            -
                    duration_list.append(duraction)
         | 
| 485 | 
            -
                    text_list.append(text)
         | 
| 486 | 
            -
             | 
| 487 | 
            -
                    result.append({"audio_path": file_audio, "text": text, "duration": duraction})
         | 
| 488 | 
            -
             | 
| 489 | 
            -
                    lenght += duraction
         | 
| 490 | 
            -
             | 
| 491 | 
            -
                if duration_list == []:
         | 
| 492 | 
            -
                    error_files_text = "\n".join(error_files)
         | 
| 493 | 
            -
                    return f"Error: No audio files found in the specified path : \n{error_files_text}"
         | 
| 494 | 
            -
             | 
| 495 | 
            -
                min_second = round(min(duration_list), 2)
         | 
| 496 | 
            -
                max_second = round(max(duration_list), 2)
         | 
| 497 | 
            -
             | 
| 498 | 
            -
                with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
         | 
| 499 | 
            -
                    for line in progress.tqdm(result, total=len(result), desc="prepare data"):
         | 
| 500 | 
            -
                        writer.write(line)
         | 
| 501 | 
            -
             | 
| 502 | 
            -
                with open(file_duration, "w", encoding="utf-8") as f:
         | 
| 503 | 
            -
                    json.dump({"duration": duration_list}, f, ensure_ascii=False)
         | 
| 504 | 
            -
             | 
| 505 | 
            -
                file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
         | 
| 506 | 
            -
                if not os.path.isfile(file_vocab_finetune):
         | 
| 507 | 
            -
                    return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
         | 
| 508 | 
            -
                shutil.copy2(file_vocab_finetune, file_vocab)
         | 
| 509 | 
            -
             | 
| 510 | 
            -
                if error_files != []:
         | 
| 511 | 
            -
                    error_text = "error files\n" + "\n".join(error_files)
         | 
| 512 | 
            -
                else:
         | 
| 513 | 
            -
                    error_text = ""
         | 
| 514 | 
            -
             | 
| 515 | 
            -
                return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
         | 
| 516 | 
            -
             | 
| 517 | 
            -
             | 
| 518 | 
            -
            def check_user(value):
         | 
| 519 | 
            -
                return gr.update(visible=not value), gr.update(visible=value)
         | 
| 520 | 
            -
             | 
| 521 | 
            -
             | 
| 522 | 
            -
            def calculate_train(
         | 
| 523 | 
            -
                name_project,
         | 
| 524 | 
            -
                batch_size_type,
         | 
| 525 | 
            -
                max_samples,
         | 
| 526 | 
            -
                learning_rate,
         | 
| 527 | 
            -
                num_warmup_updates,
         | 
| 528 | 
            -
                save_per_updates,
         | 
| 529 | 
            -
                last_per_steps,
         | 
| 530 | 
            -
                finetune,
         | 
| 531 | 
            -
            ):
         | 
| 532 | 
            -
                name_project += "_pinyin"
         | 
| 533 | 
            -
                path_project = os.path.join(path_data, name_project)
         | 
| 534 | 
            -
                file_duraction = os.path.join(path_project, "duration.json")
         | 
| 535 | 
            -
             | 
| 536 | 
            -
                if not os.path.isfile(file_duraction):
         | 
| 537 | 
            -
                    return (
         | 
| 538 | 
            -
                        1000,
         | 
| 539 | 
            -
                        max_samples,
         | 
| 540 | 
            -
                        num_warmup_updates,
         | 
| 541 | 
            -
                        save_per_updates,
         | 
| 542 | 
            -
                        last_per_steps,
         | 
| 543 | 
            -
                        "project not found !",
         | 
| 544 | 
            -
                        learning_rate,
         | 
| 545 | 
            -
                    )
         | 
| 546 | 
            -
             | 
| 547 | 
            -
                with open(file_duraction, "r") as file:
         | 
| 548 | 
            -
                    data = json.load(file)
         | 
| 549 | 
            -
             | 
| 550 | 
            -
                duration_list = data["duration"]
         | 
| 551 | 
            -
             | 
| 552 | 
            -
                samples = len(duration_list)
         | 
| 553 | 
            -
             | 
| 554 | 
            -
                if torch.cuda.is_available():
         | 
| 555 | 
            -
                    gpu_properties = torch.cuda.get_device_properties(0)
         | 
| 556 | 
            -
                    total_memory = gpu_properties.total_memory / (1024**3)
         | 
| 557 | 
            -
                elif torch.backends.mps.is_available():
         | 
| 558 | 
            -
                    total_memory = psutil.virtual_memory().available / (1024**3)
         | 
| 559 | 
            -
             | 
| 560 | 
            -
                if batch_size_type == "frame":
         | 
| 561 | 
            -
                    batch = int(total_memory * 0.5)
         | 
| 562 | 
            -
                    batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
         | 
| 563 | 
            -
                    batch_size_per_gpu = int(38400 / batch)
         | 
| 564 | 
            -
                else:
         | 
| 565 | 
            -
                    batch_size_per_gpu = int(total_memory / 8)
         | 
| 566 | 
            -
                    batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
         | 
| 567 | 
            -
                    batch = batch_size_per_gpu
         | 
| 568 | 
            -
             | 
| 569 | 
            -
                if batch_size_per_gpu <= 0:
         | 
| 570 | 
            -
                    batch_size_per_gpu = 1
         | 
| 571 | 
            -
             | 
| 572 | 
            -
                if samples < 64:
         | 
| 573 | 
            -
                    max_samples = int(samples * 0.25)
         | 
| 574 | 
            -
                else:
         | 
| 575 | 
            -
                    max_samples = 64
         | 
| 576 | 
            -
             | 
| 577 | 
            -
                num_warmup_updates = int(samples * 0.05)
         | 
| 578 | 
            -
                save_per_updates = int(samples * 0.10)
         | 
| 579 | 
            -
                last_per_steps = int(save_per_updates * 5)
         | 
| 580 | 
            -
             | 
| 581 | 
            -
                max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
         | 
| 582 | 
            -
                num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
         | 
| 583 | 
            -
                save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
         | 
| 584 | 
            -
                last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
         | 
| 585 | 
            -
             | 
| 586 | 
            -
                if finetune:
         | 
| 587 | 
            -
                    learning_rate = 1e-5
         | 
| 588 | 
            -
                else:
         | 
| 589 | 
            -
                    learning_rate = 7.5e-5
         | 
| 590 | 
            -
             | 
| 591 | 
            -
                return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
         | 
| 592 | 
            -
             | 
| 593 | 
            -
             | 
| 594 | 
            -
            def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
         | 
| 595 | 
            -
                try:
         | 
| 596 | 
            -
                    checkpoint = torch.load(checkpoint_path)
         | 
| 597 | 
            -
                    print("Original Checkpoint Keys:", checkpoint.keys())
         | 
| 598 | 
            -
             | 
| 599 | 
            -
                    ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
         | 
| 600 | 
            -
             | 
| 601 | 
            -
                    if ema_model_state_dict is not None:
         | 
| 602 | 
            -
                        new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
         | 
| 603 | 
            -
                        torch.save(new_checkpoint, new_checkpoint_path)
         | 
| 604 | 
            -
                        return f"New checkpoint saved at: {new_checkpoint_path}"
         | 
| 605 | 
            -
                    else:
         | 
| 606 | 
            -
                        return "No 'ema_model_state_dict' found in the checkpoint."
         | 
| 607 | 
            -
             | 
| 608 | 
            -
                except Exception as e:
         | 
| 609 | 
            -
                    return f"An error occurred: {e}"
         | 
| 610 | 
            -
             | 
| 611 | 
            -
             | 
| 612 | 
            -
            def vocab_check(project_name):
         | 
| 613 | 
            -
                name_project = project_name + "_pinyin"
         | 
| 614 | 
            -
                path_project = os.path.join(path_data, name_project)
         | 
| 615 | 
            -
             | 
| 616 | 
            -
                file_metadata = os.path.join(path_project, "metadata.csv")
         | 
| 617 | 
            -
             | 
| 618 | 
            -
                file_vocab = "data/Emilia_ZH_EN_pinyin/vocab.txt"
         | 
| 619 | 
            -
                if not os.path.isfile(file_vocab):
         | 
| 620 | 
            -
                    return f"the file {file_vocab} not found !"
         | 
| 621 | 
            -
             | 
| 622 | 
            -
                with open(file_vocab, "r", encoding="utf-8") as f:
         | 
| 623 | 
            -
                    data = f.read()
         | 
| 624 | 
            -
             | 
| 625 | 
            -
                vocab = data.split("\n")
         | 
| 626 | 
            -
             | 
| 627 | 
            -
                if not os.path.isfile(file_metadata):
         | 
| 628 | 
            -
                    return f"the file {file_metadata} not found !"
         | 
| 629 | 
            -
             | 
| 630 | 
            -
                with open(file_metadata, "r", encoding="utf-8") as f:
         | 
| 631 | 
            -
                    data = f.read()
         | 
| 632 | 
            -
             | 
| 633 | 
            -
                miss_symbols = []
         | 
| 634 | 
            -
                miss_symbols_keep = {}
         | 
| 635 | 
            -
                for item in data.split("\n"):
         | 
| 636 | 
            -
                    sp = item.split("|")
         | 
| 637 | 
            -
                    if len(sp) != 2:
         | 
| 638 | 
            -
                        continue
         | 
| 639 | 
            -
             | 
| 640 | 
            -
                    text = sp[1].lower().strip()
         | 
| 641 | 
            -
             | 
| 642 | 
            -
                    for t in text:
         | 
| 643 | 
            -
                        if t not in vocab and t not in miss_symbols_keep:
         | 
| 644 | 
            -
                            miss_symbols.append(t)
         | 
| 645 | 
            -
                            miss_symbols_keep[t] = t
         | 
| 646 | 
            -
                if miss_symbols == []:
         | 
| 647 | 
            -
                    info = "You can train using your language !"
         | 
| 648 | 
            -
                else:
         | 
| 649 | 
            -
                    info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
         | 
| 650 | 
            -
             | 
| 651 | 
            -
                return info
         | 
| 652 | 
            -
             | 
| 653 | 
            -
             | 
| 654 | 
            -
            def get_random_sample_prepare(project_name):
         | 
| 655 | 
            -
                name_project = project_name + "_pinyin"
         | 
| 656 | 
            -
                path_project = os.path.join(path_data, name_project)
         | 
| 657 | 
            -
                file_arrow = os.path.join(path_project, "raw.arrow")
         | 
| 658 | 
            -
                if not os.path.isfile(file_arrow):
         | 
| 659 | 
            -
                    return "", None
         | 
| 660 | 
            -
                dataset = Dataset_.from_file(file_arrow)
         | 
| 661 | 
            -
                random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0])
         | 
| 662 | 
            -
                text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]"
         | 
| 663 | 
            -
                audio_path = random_sample["audio_path"][0]
         | 
| 664 | 
            -
                return text, audio_path
         | 
| 665 | 
            -
             | 
| 666 | 
            -
             | 
| 667 | 
            -
            def get_random_sample_transcribe(project_name):
         | 
| 668 | 
            -
                name_project = project_name + "_pinyin"
         | 
| 669 | 
            -
                path_project = os.path.join(path_data, name_project)
         | 
| 670 | 
            -
                file_metadata = os.path.join(path_project, "metadata.csv")
         | 
| 671 | 
            -
                if not os.path.isfile(file_metadata):
         | 
| 672 | 
            -
                    return "", None
         | 
| 673 | 
            -
             | 
| 674 | 
            -
                data = ""
         | 
| 675 | 
            -
                with open(file_metadata, "r", encoding="utf-8") as f:
         | 
| 676 | 
            -
                    data = f.read()
         | 
| 677 | 
            -
             | 
| 678 | 
            -
                list_data = []
         | 
| 679 | 
            -
                for item in data.split("\n"):
         | 
| 680 | 
            -
                    sp = item.split("|")
         | 
| 681 | 
            -
                    if len(sp) != 2:
         | 
| 682 | 
            -
                        continue
         | 
| 683 | 
            -
                    list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]])
         | 
| 684 | 
            -
             | 
| 685 | 
            -
                if list_data == []:
         | 
| 686 | 
            -
                    return "", None
         | 
| 687 | 
            -
             | 
| 688 | 
            -
                random_item = random.choice(list_data)
         | 
| 689 | 
            -
             | 
| 690 | 
            -
                return random_item[1], random_item[0]
         | 
| 691 | 
            -
             | 
| 692 | 
            -
             | 
| 693 | 
            -
            def get_random_sample_infer(project_name):
         | 
| 694 | 
            -
                text, audio = get_random_sample_transcribe(project_name)
         | 
| 695 | 
            -
                return (
         | 
| 696 | 
            -
                    text,
         | 
| 697 | 
            -
                    text,
         | 
| 698 | 
            -
                    audio,
         | 
| 699 | 
            -
                )
         | 
| 700 | 
            -
             | 
| 701 | 
            -
             | 
| 702 | 
            -
            def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
         | 
| 703 | 
            -
                global last_checkpoint, last_device, tts_api
         | 
| 704 | 
            -
             | 
| 705 | 
            -
                if not os.path.isfile(file_checkpoint):
         | 
| 706 | 
            -
                    return None
         | 
| 707 | 
            -
             | 
| 708 | 
            -
                if training_process is not None:
         | 
| 709 | 
            -
                    device_test = "cpu"
         | 
| 710 | 
            -
                else:
         | 
| 711 | 
            -
                    device_test = None
         | 
| 712 | 
            -
             | 
| 713 | 
            -
                if last_checkpoint != file_checkpoint or last_device != device_test:
         | 
| 714 | 
            -
                    if last_checkpoint != file_checkpoint:
         | 
| 715 | 
            -
                        last_checkpoint = file_checkpoint
         | 
| 716 | 
            -
                    if last_device != device_test:
         | 
| 717 | 
            -
                        last_device = device_test
         | 
| 718 | 
            -
             | 
| 719 | 
            -
                    tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)
         | 
| 720 | 
            -
             | 
| 721 | 
            -
                    print("update", device_test, file_checkpoint)
         | 
| 722 | 
            -
             | 
| 723 | 
            -
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         | 
| 724 | 
            -
                    tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
         | 
| 725 | 
            -
                    return f.name
         | 
| 726 | 
            -
             | 
| 727 | 
            -
             | 
| 728 | 
            -
            with gr.Blocks() as app:
         | 
| 729 | 
            -
                with gr.Row():
         | 
| 730 | 
            -
                    project_name = gr.Textbox(label="project name", value="my_speak")
         | 
| 731 | 
            -
                    bt_create = gr.Button("create new project")
         | 
| 732 | 
            -
             | 
| 733 | 
            -
                bt_create.click(fn=create_data_project, inputs=[project_name])
         | 
| 734 | 
            -
             | 
| 735 | 
            -
                with gr.Tabs():
         | 
| 736 | 
            -
                    with gr.TabItem("transcribe Data"):
         | 
| 737 | 
            -
                        ch_manual = gr.Checkbox(label="user", value=False)
         | 
| 738 | 
            -
             | 
| 739 | 
            -
                        mark_info_transcribe = gr.Markdown(
         | 
| 740 | 
            -
                            """```plaintext    
         | 
| 741 | 
            -
                 Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory. 
         | 
| 742 | 
            -
                             
         | 
| 743 | 
            -
                 my_speak/
         | 
| 744 | 
            -
                 β
         | 
| 745 | 
            -
                 βββ dataset/
         | 
| 746 | 
            -
                     βββ audio1.wav
         | 
| 747 | 
            -
                     βββ audio2.wav
         | 
| 748 | 
            -
                     ...
         | 
| 749 | 
            -
                 ```""",
         | 
| 750 | 
            -
                            visible=False,
         | 
| 751 | 
            -
                        )
         | 
| 752 | 
            -
             | 
| 753 | 
            -
                        audio_speaker = gr.File(label="voice", type="filepath", file_count="multiple")
         | 
| 754 | 
            -
                        txt_lang = gr.Text(label="Language", value="english")
         | 
| 755 | 
            -
                        bt_transcribe = bt_create = gr.Button("transcribe")
         | 
| 756 | 
            -
                        txt_info_transcribe = gr.Text(label="info", value="")
         | 
| 757 | 
            -
                        bt_transcribe.click(
         | 
| 758 | 
            -
                            fn=transcribe_all,
         | 
| 759 | 
            -
                            inputs=[project_name, audio_speaker, txt_lang, ch_manual],
         | 
| 760 | 
            -
                            outputs=[txt_info_transcribe],
         | 
| 761 | 
            -
                        )
         | 
| 762 | 
            -
                        ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
         | 
| 763 | 
            -
             | 
| 764 | 
            -
                        random_sample_transcribe = gr.Button("random sample")
         | 
| 765 | 
            -
             | 
| 766 | 
            -
                        with gr.Row():
         | 
| 767 | 
            -
                            random_text_transcribe = gr.Text(label="Text")
         | 
| 768 | 
            -
                            random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
         | 
| 769 | 
            -
             | 
| 770 | 
            -
                        random_sample_transcribe.click(
         | 
| 771 | 
            -
                            fn=get_random_sample_transcribe,
         | 
| 772 | 
            -
                            inputs=[project_name],
         | 
| 773 | 
            -
                            outputs=[random_text_transcribe, random_audio_transcribe],
         | 
| 774 | 
            -
                        )
         | 
| 775 | 
            -
             | 
| 776 | 
            -
                    with gr.TabItem("prepare Data"):
         | 
| 777 | 
            -
                        gr.Markdown(
         | 
| 778 | 
            -
                            """```plaintext    
         | 
| 779 | 
            -
                 place all your wavs folder and your metadata.csv file in {your name project}                                 
         | 
| 780 | 
            -
                 my_speak/
         | 
| 781 | 
            -
                 β
         | 
| 782 | 
            -
                 βββ wavs/
         | 
| 783 | 
            -
                 β   βββ audio1.wav
         | 
| 784 | 
            -
                 β   βββ audio2.wav
         | 
| 785 | 
            -
                 |   ...
         | 
| 786 | 
            -
                 β
         | 
| 787 | 
            -
                 βββ metadata.csv
         | 
| 788 | 
            -
                  
         | 
| 789 | 
            -
                 file format metadata.csv
         | 
| 790 | 
            -
             | 
| 791 | 
            -
                 audio1|text1
         | 
| 792 | 
            -
                 audio2|text1
         | 
| 793 | 
            -
                 ...
         | 
| 794 | 
            -
             | 
| 795 | 
            -
                 ```"""
         | 
| 796 | 
            -
                        )
         | 
| 797 | 
            -
             | 
| 798 | 
            -
                        bt_prepare = bt_create = gr.Button("prepare")
         | 
| 799 | 
            -
                        txt_info_prepare = gr.Text(label="info", value="")
         | 
| 800 | 
            -
                        bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
         | 
| 801 | 
            -
             | 
| 802 | 
            -
                        random_sample_prepare = gr.Button("random sample")
         | 
| 803 | 
            -
             | 
| 804 | 
            -
                        with gr.Row():
         | 
| 805 | 
            -
                            random_text_prepare = gr.Text(label="Pinyin")
         | 
| 806 | 
            -
                            random_audio_prepare = gr.Audio(label="Audio", type="filepath")
         | 
| 807 | 
            -
             | 
| 808 | 
            -
                        random_sample_prepare.click(
         | 
| 809 | 
            -
                            fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
         | 
| 810 | 
            -
                        )
         | 
| 811 | 
            -
             | 
| 812 | 
            -
                    with gr.TabItem("train Data"):
         | 
| 813 | 
            -
                        with gr.Row():
         | 
| 814 | 
            -
                            bt_calculate = bt_create = gr.Button("Auto Settings")
         | 
| 815 | 
            -
                            ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
         | 
| 816 | 
            -
                            lb_samples = gr.Label(label="samples")
         | 
| 817 | 
            -
                            batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
         | 
| 818 | 
            -
             | 
| 819 | 
            -
                        with gr.Row():
         | 
| 820 | 
            -
                            exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
         | 
| 821 | 
            -
                            learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
         | 
| 822 | 
            -
             | 
| 823 | 
            -
                        with gr.Row():
         | 
| 824 | 
            -
                            batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
         | 
| 825 | 
            -
                            max_samples = gr.Number(label="Max Samples", value=64)
         | 
| 826 | 
            -
             | 
| 827 | 
            -
                        with gr.Row():
         | 
| 828 | 
            -
                            grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
         | 
| 829 | 
            -
                            max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
         | 
| 830 | 
            -
             | 
| 831 | 
            -
                        with gr.Row():
         | 
| 832 | 
            -
                            epochs = gr.Number(label="Epochs", value=10)
         | 
| 833 | 
            -
                            num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
         | 
| 834 | 
            -
             | 
| 835 | 
            -
                        with gr.Row():
         | 
| 836 | 
            -
                            save_per_updates = gr.Number(label="Save per Updates", value=10)
         | 
| 837 | 
            -
                            last_per_steps = gr.Number(label="Last per Steps", value=50)
         | 
| 838 | 
            -
             | 
| 839 | 
            -
                        with gr.Row():
         | 
| 840 | 
            -
                            start_button = gr.Button("Start Training")
         | 
| 841 | 
            -
                            stop_button = gr.Button("Stop Training", interactive=False)
         | 
| 842 | 
            -
             | 
| 843 | 
            -
                        txt_info_train = gr.Text(label="info", value="")
         | 
| 844 | 
            -
                        start_button.click(
         | 
| 845 | 
            -
                            fn=start_training,
         | 
| 846 | 
            -
                            inputs=[
         | 
| 847 | 
            -
                                project_name,
         | 
| 848 | 
            -
                                exp_name,
         | 
| 849 | 
            -
                                learning_rate,
         | 
| 850 | 
            -
                                batch_size_per_gpu,
         | 
| 851 | 
            -
                                batch_size_type,
         | 
| 852 | 
            -
                                max_samples,
         | 
| 853 | 
            -
                                grad_accumulation_steps,
         | 
| 854 | 
            -
                                max_grad_norm,
         | 
| 855 | 
            -
                                epochs,
         | 
| 856 | 
            -
                                num_warmup_updates,
         | 
| 857 | 
            -
                                save_per_updates,
         | 
| 858 | 
            -
                                last_per_steps,
         | 
| 859 | 
            -
                                ch_finetune,
         | 
| 860 | 
            -
                            ],
         | 
| 861 | 
            -
                            outputs=[txt_info_train, start_button, stop_button],
         | 
| 862 | 
            -
                        )
         | 
| 863 | 
            -
                        stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
         | 
| 864 | 
            -
                        bt_calculate.click(
         | 
| 865 | 
            -
                            fn=calculate_train,
         | 
| 866 | 
            -
                            inputs=[
         | 
| 867 | 
            -
                                project_name,
         | 
| 868 | 
            -
                                batch_size_type,
         | 
| 869 | 
            -
                                max_samples,
         | 
| 870 | 
            -
                                learning_rate,
         | 
| 871 | 
            -
                                num_warmup_updates,
         | 
| 872 | 
            -
                                save_per_updates,
         | 
| 873 | 
            -
                                last_per_steps,
         | 
| 874 | 
            -
                                ch_finetune,
         | 
| 875 | 
            -
                            ],
         | 
| 876 | 
            -
                            outputs=[
         | 
| 877 | 
            -
                                batch_size_per_gpu,
         | 
| 878 | 
            -
                                max_samples,
         | 
| 879 | 
            -
                                num_warmup_updates,
         | 
| 880 | 
            -
                                save_per_updates,
         | 
| 881 | 
            -
                                last_per_steps,
         | 
| 882 | 
            -
                                lb_samples,
         | 
| 883 | 
            -
                                learning_rate,
         | 
| 884 | 
            -
                            ],
         | 
| 885 | 
            -
                        )
         | 
| 886 | 
            -
             | 
| 887 | 
            -
                    with gr.TabItem("reduse checkpoint"):
         | 
| 888 | 
            -
                        txt_path_checkpoint = gr.Text(label="path checkpoint :")
         | 
| 889 | 
            -
                        txt_path_checkpoint_small = gr.Text(label="path output :")
         | 
| 890 | 
            -
                        txt_info_reduse = gr.Text(label="info", value="")
         | 
| 891 | 
            -
                        reduse_button = gr.Button("reduse")
         | 
| 892 | 
            -
                        reduse_button.click(
         | 
| 893 | 
            -
                            fn=extract_and_save_ema_model,
         | 
| 894 | 
            -
                            inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
         | 
| 895 | 
            -
                            outputs=[txt_info_reduse],
         | 
| 896 | 
            -
                        )
         | 
| 897 | 
            -
             | 
| 898 | 
            -
                    with gr.TabItem("vocab check experiment"):
         | 
| 899 | 
            -
                        check_button = gr.Button("check vocab")
         | 
| 900 | 
            -
                        txt_info_check = gr.Text(label="info", value="")
         | 
| 901 | 
            -
                        check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
         | 
| 902 | 
            -
             | 
| 903 | 
            -
                    with gr.TabItem("test model"):
         | 
| 904 | 
            -
                        exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
         | 
| 905 | 
            -
                        nfe_step = gr.Number(label="n_step", value=32)
         | 
| 906 | 
            -
                        file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
         | 
| 907 | 
            -
             | 
| 908 | 
            -
                        random_sample_infer = gr.Button("random sample")
         | 
| 909 | 
            -
             | 
| 910 | 
            -
                        ref_text = gr.Textbox(label="ref text")
         | 
| 911 | 
            -
                        ref_audio = gr.Audio(label="audio ref", type="filepath")
         | 
| 912 | 
            -
                        gen_text = gr.Textbox(label="gen text")
         | 
| 913 | 
            -
                        random_sample_infer.click(
         | 
| 914 | 
            -
                            fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
         | 
| 915 | 
            -
                        )
         | 
| 916 | 
            -
                        check_button_infer = gr.Button("infer")
         | 
| 917 | 
            -
                        gen_audio = gr.Audio(label="audio gen", type="filepath")
         | 
| 918 | 
            -
             | 
| 919 | 
            -
                        check_button_infer.click(
         | 
| 920 | 
            -
                            fn=infer,
         | 
| 921 | 
            -
                            inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
         | 
| 922 | 
            -
                            outputs=[gen_audio],
         | 
| 923 | 
            -
                        )
         | 
| 924 | 
            -
             | 
| 925 | 
            -
             | 
| 926 | 
            -
            @click.command()
         | 
| 927 | 
            -
            @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
         | 
| 928 | 
            -
            @click.option("--host", "-H", default=None, help="Host to run the app on")
         | 
| 929 | 
            -
            @click.option(
         | 
| 930 | 
            -
                "--share",
         | 
| 931 | 
            -
                "-s",
         | 
| 932 | 
            -
                default=False,
         | 
| 933 | 
            -
                is_flag=True,
         | 
| 934 | 
            -
                help="Share the app via Gradio share link",
         | 
| 935 | 
            -
            )
         | 
| 936 | 
            -
            @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
         | 
| 937 | 
            -
            def main(port, host, share, api):
         | 
| 938 | 
            -
                global app
         | 
| 939 | 
            -
                print("Starting app...")
         | 
| 940 | 
            -
                app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
         | 
| 941 | 
            -
             | 
| 942 | 
            -
             | 
| 943 | 
            -
            if __name__ == "__main__":
         | 
| 944 | 
            -
                main()
         | 
|  | |
| 1 | 
            +
            import gc
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            import platform
         | 
| 5 | 
            +
            import psutil
         | 
| 6 | 
            +
            import random
         | 
| 7 | 
            +
            import signal
         | 
| 8 | 
            +
            import shutil
         | 
| 9 | 
            +
            import subprocess
         | 
| 10 | 
            +
            import sys
         | 
| 11 | 
            +
            import tempfile
         | 
| 12 | 
            +
            import time
         | 
| 13 | 
            +
            from glob import glob
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            import click
         | 
| 16 | 
            +
            import gradio as gr
         | 
| 17 | 
            +
            import librosa
         | 
| 18 | 
            +
            import numpy as np
         | 
| 19 | 
            +
            import torch
         | 
| 20 | 
            +
            import torchaudio
         | 
| 21 | 
            +
            from datasets import Dataset as Dataset_
         | 
| 22 | 
            +
            from datasets.arrow_writer import ArrowWriter
         | 
| 23 | 
            +
            from scipy.io import wavfile
         | 
| 24 | 
            +
            from transformers import pipeline
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            from f5_tts.api import F5TTS
         | 
| 27 | 
            +
            from f5_tts.model.utils import convert_char_to_pinyin
         | 
| 28 | 
            +
             | 
| 29 | 
            +
             | 
| 30 | 
            +
            training_process = None
         | 
| 31 | 
            +
            system = platform.system()
         | 
| 32 | 
            +
            python_executable = sys.executable or "python"
         | 
| 33 | 
            +
            tts_api = None
         | 
| 34 | 
            +
            last_checkpoint = ""
         | 
| 35 | 
            +
            last_device = ""
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            path_data = "data"
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            pipe = None
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            # Load metadata
         | 
| 45 | 
            +
            def get_audio_duration(audio_path):
         | 
| 46 | 
            +
                """Calculate the duration of an audio file."""
         | 
| 47 | 
            +
                audio, sample_rate = torchaudio.load(audio_path)
         | 
| 48 | 
            +
                num_channels = audio.shape[0]
         | 
| 49 | 
            +
                return audio.shape[1] / (sample_rate * num_channels)
         | 
| 50 | 
            +
             | 
| 51 | 
            +
             | 
| 52 | 
            +
            def clear_text(text):
         | 
| 53 | 
            +
                """Clean and prepare text by lowering the case and stripping whitespace."""
         | 
| 54 | 
            +
                return text.lower().strip()
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def get_rms(
         | 
| 58 | 
            +
                y,
         | 
| 59 | 
            +
                frame_length=2048,
         | 
| 60 | 
            +
                hop_length=512,
         | 
| 61 | 
            +
                pad_mode="constant",
         | 
| 62 | 
            +
            ):  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
         | 
| 63 | 
            +
                padding = (int(frame_length // 2), int(frame_length // 2))
         | 
| 64 | 
            +
                y = np.pad(y, padding, mode=pad_mode)
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                axis = -1
         | 
| 67 | 
            +
                # put our new within-frame axis at the end for now
         | 
| 68 | 
            +
                out_strides = y.strides + tuple([y.strides[axis]])
         | 
| 69 | 
            +
                # Reduce the shape on the framing axis
         | 
| 70 | 
            +
                x_shape_trimmed = list(y.shape)
         | 
| 71 | 
            +
                x_shape_trimmed[axis] -= frame_length - 1
         | 
| 72 | 
            +
                out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
         | 
| 73 | 
            +
                xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
         | 
| 74 | 
            +
                if axis < 0:
         | 
| 75 | 
            +
                    target_axis = axis - 1
         | 
| 76 | 
            +
                else:
         | 
| 77 | 
            +
                    target_axis = axis + 1
         | 
| 78 | 
            +
                xw = np.moveaxis(xw, -1, target_axis)
         | 
| 79 | 
            +
                # Downsample along the target axis
         | 
| 80 | 
            +
                slices = [slice(None)] * xw.ndim
         | 
| 81 | 
            +
                slices[axis] = slice(0, None, hop_length)
         | 
| 82 | 
            +
                x = xw[tuple(slices)]
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                # Calculate power
         | 
| 85 | 
            +
                power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                return np.sqrt(power)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            class Slicer:  # https://github.com/RVC-Boss/GPT-SoVITS/blob/main/tools/slicer2.py
         | 
| 91 | 
            +
                def __init__(
         | 
| 92 | 
            +
                    self,
         | 
| 93 | 
            +
                    sr: int,
         | 
| 94 | 
            +
                    threshold: float = -40.0,
         | 
| 95 | 
            +
                    min_length: int = 2000,
         | 
| 96 | 
            +
                    min_interval: int = 300,
         | 
| 97 | 
            +
                    hop_size: int = 20,
         | 
| 98 | 
            +
                    max_sil_kept: int = 2000,
         | 
| 99 | 
            +
                ):
         | 
| 100 | 
            +
                    if not min_length >= min_interval >= hop_size:
         | 
| 101 | 
            +
                        raise ValueError("The following condition must be satisfied: min_length >= min_interval >= hop_size")
         | 
| 102 | 
            +
                    if not max_sil_kept >= hop_size:
         | 
| 103 | 
            +
                        raise ValueError("The following condition must be satisfied: max_sil_kept >= hop_size")
         | 
| 104 | 
            +
                    min_interval = sr * min_interval / 1000
         | 
| 105 | 
            +
                    self.threshold = 10 ** (threshold / 20.0)
         | 
| 106 | 
            +
                    self.hop_size = round(sr * hop_size / 1000)
         | 
| 107 | 
            +
                    self.win_size = min(round(min_interval), 4 * self.hop_size)
         | 
| 108 | 
            +
                    self.min_length = round(sr * min_length / 1000 / self.hop_size)
         | 
| 109 | 
            +
                    self.min_interval = round(min_interval / self.hop_size)
         | 
| 110 | 
            +
                    self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                def _apply_slice(self, waveform, begin, end):
         | 
| 113 | 
            +
                    if len(waveform.shape) > 1:
         | 
| 114 | 
            +
                        return waveform[:, begin * self.hop_size : min(waveform.shape[1], end * self.hop_size)]
         | 
| 115 | 
            +
                    else:
         | 
| 116 | 
            +
                        return waveform[begin * self.hop_size : min(waveform.shape[0], end * self.hop_size)]
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                # @timeit
         | 
| 119 | 
            +
                def slice(self, waveform):
         | 
| 120 | 
            +
                    if len(waveform.shape) > 1:
         | 
| 121 | 
            +
                        samples = waveform.mean(axis=0)
         | 
| 122 | 
            +
                    else:
         | 
| 123 | 
            +
                        samples = waveform
         | 
| 124 | 
            +
                    if samples.shape[0] <= self.min_length:
         | 
| 125 | 
            +
                        return [waveform]
         | 
| 126 | 
            +
                    rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
         | 
| 127 | 
            +
                    sil_tags = []
         | 
| 128 | 
            +
                    silence_start = None
         | 
| 129 | 
            +
                    clip_start = 0
         | 
| 130 | 
            +
                    for i, rms in enumerate(rms_list):
         | 
| 131 | 
            +
                        # Keep looping while frame is silent.
         | 
| 132 | 
            +
                        if rms < self.threshold:
         | 
| 133 | 
            +
                            # Record start of silent frames.
         | 
| 134 | 
            +
                            if silence_start is None:
         | 
| 135 | 
            +
                                silence_start = i
         | 
| 136 | 
            +
                            continue
         | 
| 137 | 
            +
                        # Keep looping while frame is not silent and silence start has not been recorded.
         | 
| 138 | 
            +
                        if silence_start is None:
         | 
| 139 | 
            +
                            continue
         | 
| 140 | 
            +
                        # Clear recorded silence start if interval is not enough or clip is too short
         | 
| 141 | 
            +
                        is_leading_silence = silence_start == 0 and i > self.max_sil_kept
         | 
| 142 | 
            +
                        need_slice_middle = i - silence_start >= self.min_interval and i - clip_start >= self.min_length
         | 
| 143 | 
            +
                        if not is_leading_silence and not need_slice_middle:
         | 
| 144 | 
            +
                            silence_start = None
         | 
| 145 | 
            +
                            continue
         | 
| 146 | 
            +
                        # Need slicing. Record the range of silent frames to be removed.
         | 
| 147 | 
            +
                        if i - silence_start <= self.max_sil_kept:
         | 
| 148 | 
            +
                            pos = rms_list[silence_start : i + 1].argmin() + silence_start
         | 
| 149 | 
            +
                            if silence_start == 0:
         | 
| 150 | 
            +
                                sil_tags.append((0, pos))
         | 
| 151 | 
            +
                            else:
         | 
| 152 | 
            +
                                sil_tags.append((pos, pos))
         | 
| 153 | 
            +
                            clip_start = pos
         | 
| 154 | 
            +
                        elif i - silence_start <= self.max_sil_kept * 2:
         | 
| 155 | 
            +
                            pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
         | 
| 156 | 
            +
                            pos += i - self.max_sil_kept
         | 
| 157 | 
            +
                            pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
         | 
| 158 | 
            +
                            pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
         | 
| 159 | 
            +
                            if silence_start == 0:
         | 
| 160 | 
            +
                                sil_tags.append((0, pos_r))
         | 
| 161 | 
            +
                                clip_start = pos_r
         | 
| 162 | 
            +
                            else:
         | 
| 163 | 
            +
                                sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
         | 
| 164 | 
            +
                                clip_start = max(pos_r, pos)
         | 
| 165 | 
            +
                        else:
         | 
| 166 | 
            +
                            pos_l = rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start
         | 
| 167 | 
            +
                            pos_r = rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept
         | 
| 168 | 
            +
                            if silence_start == 0:
         | 
| 169 | 
            +
                                sil_tags.append((0, pos_r))
         | 
| 170 | 
            +
                            else:
         | 
| 171 | 
            +
                                sil_tags.append((pos_l, pos_r))
         | 
| 172 | 
            +
                            clip_start = pos_r
         | 
| 173 | 
            +
                        silence_start = None
         | 
| 174 | 
            +
                    # Deal with trailing silence.
         | 
| 175 | 
            +
                    total_frames = rms_list.shape[0]
         | 
| 176 | 
            +
                    if silence_start is not None and total_frames - silence_start >= self.min_interval:
         | 
| 177 | 
            +
                        silence_end = min(total_frames, silence_start + self.max_sil_kept)
         | 
| 178 | 
            +
                        pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
         | 
| 179 | 
            +
                        sil_tags.append((pos, total_frames + 1))
         | 
| 180 | 
            +
                    # Apply and return slices.
         | 
| 181 | 
            +
                    ####ι³ι’+θ΅·ε§ζΆι΄+η»ζ’ζΆι΄
         | 
| 182 | 
            +
                    if len(sil_tags) == 0:
         | 
| 183 | 
            +
                        return [[waveform, 0, int(total_frames * self.hop_size)]]
         | 
| 184 | 
            +
                    else:
         | 
| 185 | 
            +
                        chunks = []
         | 
| 186 | 
            +
                        if sil_tags[0][0] > 0:
         | 
| 187 | 
            +
                            chunks.append([self._apply_slice(waveform, 0, sil_tags[0][0]), 0, int(sil_tags[0][0] * self.hop_size)])
         | 
| 188 | 
            +
                        for i in range(len(sil_tags) - 1):
         | 
| 189 | 
            +
                            chunks.append(
         | 
| 190 | 
            +
                                [
         | 
| 191 | 
            +
                                    self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]),
         | 
| 192 | 
            +
                                    int(sil_tags[i][1] * self.hop_size),
         | 
| 193 | 
            +
                                    int(sil_tags[i + 1][0] * self.hop_size),
         | 
| 194 | 
            +
                                ]
         | 
| 195 | 
            +
                            )
         | 
| 196 | 
            +
                        if sil_tags[-1][1] < total_frames:
         | 
| 197 | 
            +
                            chunks.append(
         | 
| 198 | 
            +
                                [
         | 
| 199 | 
            +
                                    self._apply_slice(waveform, sil_tags[-1][1], total_frames),
         | 
| 200 | 
            +
                                    int(sil_tags[-1][1] * self.hop_size),
         | 
| 201 | 
            +
                                    int(total_frames * self.hop_size),
         | 
| 202 | 
            +
                                ]
         | 
| 203 | 
            +
                            )
         | 
| 204 | 
            +
                        return chunks
         | 
| 205 | 
            +
             | 
| 206 | 
            +
             | 
| 207 | 
            +
            # terminal
         | 
| 208 | 
            +
            def terminate_process_tree(pid, including_parent=True):
         | 
| 209 | 
            +
                try:
         | 
| 210 | 
            +
                    parent = psutil.Process(pid)
         | 
| 211 | 
            +
                except psutil.NoSuchProcess:
         | 
| 212 | 
            +
                    # Process already terminated
         | 
| 213 | 
            +
                    return
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                children = parent.children(recursive=True)
         | 
| 216 | 
            +
                for child in children:
         | 
| 217 | 
            +
                    try:
         | 
| 218 | 
            +
                        os.kill(child.pid, signal.SIGTERM)  # or signal.SIGKILL
         | 
| 219 | 
            +
                    except OSError:
         | 
| 220 | 
            +
                        pass
         | 
| 221 | 
            +
                if including_parent:
         | 
| 222 | 
            +
                    try:
         | 
| 223 | 
            +
                        os.kill(parent.pid, signal.SIGTERM)  # or signal.SIGKILL
         | 
| 224 | 
            +
                    except OSError:
         | 
| 225 | 
            +
                        pass
         | 
| 226 | 
            +
             | 
| 227 | 
            +
             | 
| 228 | 
            +
            def terminate_process(pid):
         | 
| 229 | 
            +
                if system == "Windows":
         | 
| 230 | 
            +
                    cmd = f"taskkill /t /f /pid {pid}"
         | 
| 231 | 
            +
                    os.system(cmd)
         | 
| 232 | 
            +
                else:
         | 
| 233 | 
            +
                    terminate_process_tree(pid)
         | 
| 234 | 
            +
             | 
| 235 | 
            +
             | 
| 236 | 
            +
            def start_training(
         | 
| 237 | 
            +
                dataset_name="",
         | 
| 238 | 
            +
                exp_name="F5TTS_Base",
         | 
| 239 | 
            +
                learning_rate=1e-4,
         | 
| 240 | 
            +
                batch_size_per_gpu=400,
         | 
| 241 | 
            +
                batch_size_type="frame",
         | 
| 242 | 
            +
                max_samples=64,
         | 
| 243 | 
            +
                grad_accumulation_steps=1,
         | 
| 244 | 
            +
                max_grad_norm=1.0,
         | 
| 245 | 
            +
                epochs=11,
         | 
| 246 | 
            +
                num_warmup_updates=200,
         | 
| 247 | 
            +
                save_per_updates=400,
         | 
| 248 | 
            +
                last_per_steps=800,
         | 
| 249 | 
            +
                finetune=True,
         | 
| 250 | 
            +
            ):
         | 
| 251 | 
            +
                global training_process, tts_api
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                if tts_api is not None:
         | 
| 254 | 
            +
                    del tts_api
         | 
| 255 | 
            +
                    gc.collect()
         | 
| 256 | 
            +
                    torch.cuda.empty_cache()
         | 
| 257 | 
            +
                    tts_api = None
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                path_project = os.path.join(path_data, dataset_name + "_pinyin")
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                if not os.path.isdir(path_project):
         | 
| 262 | 
            +
                    yield (
         | 
| 263 | 
            +
                        f"There is not project with name {dataset_name}",
         | 
| 264 | 
            +
                        gr.update(interactive=True),
         | 
| 265 | 
            +
                        gr.update(interactive=False),
         | 
| 266 | 
            +
                    )
         | 
| 267 | 
            +
                    return
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                file_raw = os.path.join(path_project, "raw.arrow")
         | 
| 270 | 
            +
                if not os.path.isfile(file_raw):
         | 
| 271 | 
            +
                    yield f"There is no file {file_raw}", gr.update(interactive=True), gr.update(interactive=False)
         | 
| 272 | 
            +
                    return
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                # Check if a training process is already running
         | 
| 275 | 
            +
                if training_process is not None:
         | 
| 276 | 
            +
                    return "Train run already!", gr.update(interactive=False), gr.update(interactive=True)
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                yield "start train", gr.update(interactive=False), gr.update(interactive=False)
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                # Command to run the training script with the specified arguments
         | 
| 281 | 
            +
                cmd = (
         | 
| 282 | 
            +
                    f"accelerate launch finetune-cli.py --exp_name {exp_name} "
         | 
| 283 | 
            +
                    f"--learning_rate {learning_rate} "
         | 
| 284 | 
            +
                    f"--batch_size_per_gpu {batch_size_per_gpu} "
         | 
| 285 | 
            +
                    f"--batch_size_type {batch_size_type} "
         | 
| 286 | 
            +
                    f"--max_samples {max_samples} "
         | 
| 287 | 
            +
                    f"--grad_accumulation_steps {grad_accumulation_steps} "
         | 
| 288 | 
            +
                    f"--max_grad_norm {max_grad_norm} "
         | 
| 289 | 
            +
                    f"--epochs {epochs} "
         | 
| 290 | 
            +
                    f"--num_warmup_updates {num_warmup_updates} "
         | 
| 291 | 
            +
                    f"--save_per_updates {save_per_updates} "
         | 
| 292 | 
            +
                    f"--last_per_steps {last_per_steps} "
         | 
| 293 | 
            +
                    f"--dataset_name {dataset_name}"
         | 
| 294 | 
            +
                )
         | 
| 295 | 
            +
                if finetune:
         | 
| 296 | 
            +
                    cmd += f" --finetune {finetune}"
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                print(cmd)
         | 
| 299 | 
            +
             | 
| 300 | 
            +
                try:
         | 
| 301 | 
            +
                    # Start the training process
         | 
| 302 | 
            +
                    training_process = subprocess.Popen(cmd, shell=True)
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                    time.sleep(5)
         | 
| 305 | 
            +
                    yield "train start", gr.update(interactive=False), gr.update(interactive=True)
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    # Wait for the training process to finish
         | 
| 308 | 
            +
                    training_process.wait()
         | 
| 309 | 
            +
                    time.sleep(1)
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                    if training_process is None:
         | 
| 312 | 
            +
                        text_info = "train stop"
         | 
| 313 | 
            +
                    else:
         | 
| 314 | 
            +
                        text_info = "train complete !"
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                except Exception as e:  # Catch all exceptions
         | 
| 317 | 
            +
                    # Ensure that we reset the training process variable in case of an error
         | 
| 318 | 
            +
                    text_info = f"An error occurred: {str(e)}"
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                training_process = None
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                yield text_info, gr.update(interactive=True), gr.update(interactive=False)
         | 
| 323 | 
            +
             | 
| 324 | 
            +
             | 
| 325 | 
            +
            def stop_training():
         | 
| 326 | 
            +
                global training_process
         | 
| 327 | 
            +
                if training_process is None:
         | 
| 328 | 
            +
                    return "Train not run !", gr.update(interactive=True), gr.update(interactive=False)
         | 
| 329 | 
            +
                terminate_process_tree(training_process.pid)
         | 
| 330 | 
            +
                training_process = None
         | 
| 331 | 
            +
                return "train stop", gr.update(interactive=True), gr.update(interactive=False)
         | 
| 332 | 
            +
             | 
| 333 | 
            +
             | 
| 334 | 
            +
            def create_data_project(name):
         | 
| 335 | 
            +
                name += "_pinyin"
         | 
| 336 | 
            +
                os.makedirs(os.path.join(path_data, name), exist_ok=True)
         | 
| 337 | 
            +
                os.makedirs(os.path.join(path_data, name, "dataset"), exist_ok=True)
         | 
| 338 | 
            +
             | 
| 339 | 
            +
             | 
| 340 | 
            +
            def transcribe(file_audio, language="english"):
         | 
| 341 | 
            +
                global pipe
         | 
| 342 | 
            +
             | 
| 343 | 
            +
                if pipe is None:
         | 
| 344 | 
            +
                    pipe = pipeline(
         | 
| 345 | 
            +
                        "automatic-speech-recognition",
         | 
| 346 | 
            +
                        model="openai/whisper-large-v3-turbo",
         | 
| 347 | 
            +
                        torch_dtype=torch.float16,
         | 
| 348 | 
            +
                        device=device,
         | 
| 349 | 
            +
                    )
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                text_transcribe = pipe(
         | 
| 352 | 
            +
                    file_audio,
         | 
| 353 | 
            +
                    chunk_length_s=30,
         | 
| 354 | 
            +
                    batch_size=128,
         | 
| 355 | 
            +
                    generate_kwargs={"task": "transcribe", "language": language},
         | 
| 356 | 
            +
                    return_timestamps=False,
         | 
| 357 | 
            +
                )["text"].strip()
         | 
| 358 | 
            +
                return text_transcribe
         | 
| 359 | 
            +
             | 
| 360 | 
            +
             | 
| 361 | 
            +
            def transcribe_all(name_project, audio_files, language, user=False, progress=gr.Progress()):
         | 
| 362 | 
            +
                name_project += "_pinyin"
         | 
| 363 | 
            +
                path_project = os.path.join(path_data, name_project)
         | 
| 364 | 
            +
                path_dataset = os.path.join(path_project, "dataset")
         | 
| 365 | 
            +
                path_project_wavs = os.path.join(path_project, "wavs")
         | 
| 366 | 
            +
                file_metadata = os.path.join(path_project, "metadata.csv")
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                if audio_files is None:
         | 
| 369 | 
            +
                    return "You need to load an audio file."
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                if os.path.isdir(path_project_wavs):
         | 
| 372 | 
            +
                    shutil.rmtree(path_project_wavs)
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                if os.path.isfile(file_metadata):
         | 
| 375 | 
            +
                    os.remove(file_metadata)
         | 
| 376 | 
            +
             | 
| 377 | 
            +
                os.makedirs(path_project_wavs, exist_ok=True)
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                if user:
         | 
| 380 | 
            +
                    file_audios = [
         | 
| 381 | 
            +
                        file
         | 
| 382 | 
            +
                        for format in ("*.wav", "*.ogg", "*.opus", "*.mp3", "*.flac")
         | 
| 383 | 
            +
                        for file in glob(os.path.join(path_dataset, format))
         | 
| 384 | 
            +
                    ]
         | 
| 385 | 
            +
                    if file_audios == []:
         | 
| 386 | 
            +
                        return "No audio file was found in the dataset."
         | 
| 387 | 
            +
                else:
         | 
| 388 | 
            +
                    file_audios = audio_files
         | 
| 389 | 
            +
             | 
| 390 | 
            +
                alpha = 0.5
         | 
| 391 | 
            +
                _max = 1.0
         | 
| 392 | 
            +
                slicer = Slicer(24000)
         | 
| 393 | 
            +
             | 
| 394 | 
            +
                num = 0
         | 
| 395 | 
            +
                error_num = 0
         | 
| 396 | 
            +
                data = ""
         | 
| 397 | 
            +
                for file_audio in progress.tqdm(file_audios, desc="transcribe files", total=len((file_audios))):
         | 
| 398 | 
            +
                    audio, _ = librosa.load(file_audio, sr=24000, mono=True)
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                    list_slicer = slicer.slice(audio)
         | 
| 401 | 
            +
                    for chunk, start, end in progress.tqdm(list_slicer, total=len(list_slicer), desc="slicer files"):
         | 
| 402 | 
            +
                        name_segment = os.path.join(f"segment_{num}")
         | 
| 403 | 
            +
                        file_segment = os.path.join(path_project_wavs, f"{name_segment}.wav")
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                        tmp_max = np.abs(chunk).max()
         | 
| 406 | 
            +
                        if tmp_max > 1:
         | 
| 407 | 
            +
                            chunk /= tmp_max
         | 
| 408 | 
            +
                        chunk = (chunk / tmp_max * (_max * alpha)) + (1 - alpha) * chunk
         | 
| 409 | 
            +
                        wavfile.write(file_segment, 24000, (chunk * 32767).astype(np.int16))
         | 
| 410 | 
            +
             | 
| 411 | 
            +
                        try:
         | 
| 412 | 
            +
                            text = transcribe(file_segment, language)
         | 
| 413 | 
            +
                            text = text.lower().strip().replace('"', "")
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                            data += f"{name_segment}|{text}\n"
         | 
| 416 | 
            +
             | 
| 417 | 
            +
                            num += 1
         | 
| 418 | 
            +
                        except:  # noqa: E722
         | 
| 419 | 
            +
                            error_num += 1
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                with open(file_metadata, "w", encoding="utf-8") as f:
         | 
| 422 | 
            +
                    f.write(data)
         | 
| 423 | 
            +
             | 
| 424 | 
            +
                if error_num != []:
         | 
| 425 | 
            +
                    error_text = f"\nerror files : {error_num}"
         | 
| 426 | 
            +
                else:
         | 
| 427 | 
            +
                    error_text = ""
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                return f"transcribe complete samples : {num}\npath : {path_project_wavs}{error_text}"
         | 
| 430 | 
            +
             | 
| 431 | 
            +
             | 
| 432 | 
            +
            def format_seconds_to_hms(seconds):
         | 
| 433 | 
            +
                hours = int(seconds / 3600)
         | 
| 434 | 
            +
                minutes = int((seconds % 3600) / 60)
         | 
| 435 | 
            +
                seconds = seconds % 60
         | 
| 436 | 
            +
                return "{:02d}:{:02d}:{:02d}".format(hours, minutes, int(seconds))
         | 
| 437 | 
            +
             | 
| 438 | 
            +
             | 
| 439 | 
            +
            def create_metadata(name_project, progress=gr.Progress()):
         | 
| 440 | 
            +
                name_project += "_pinyin"
         | 
| 441 | 
            +
                path_project = os.path.join(path_data, name_project)
         | 
| 442 | 
            +
                path_project_wavs = os.path.join(path_project, "wavs")
         | 
| 443 | 
            +
                file_metadata = os.path.join(path_project, "metadata.csv")
         | 
| 444 | 
            +
                file_raw = os.path.join(path_project, "raw.arrow")
         | 
| 445 | 
            +
                file_duration = os.path.join(path_project, "duration.json")
         | 
| 446 | 
            +
                file_vocab = os.path.join(path_project, "vocab.txt")
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                if not os.path.isfile(file_metadata):
         | 
| 449 | 
            +
                    return "The file was not found in " + file_metadata
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                with open(file_metadata, "r", encoding="utf-8") as f:
         | 
| 452 | 
            +
                    data = f.read()
         | 
| 453 | 
            +
             | 
| 454 | 
            +
                audio_path_list = []
         | 
| 455 | 
            +
                text_list = []
         | 
| 456 | 
            +
                duration_list = []
         | 
| 457 | 
            +
             | 
| 458 | 
            +
                count = data.split("\n")
         | 
| 459 | 
            +
                lenght = 0
         | 
| 460 | 
            +
                result = []
         | 
| 461 | 
            +
                error_files = []
         | 
| 462 | 
            +
                for line in progress.tqdm(data.split("\n"), total=count):
         | 
| 463 | 
            +
                    sp_line = line.split("|")
         | 
| 464 | 
            +
                    if len(sp_line) != 2:
         | 
| 465 | 
            +
                        continue
         | 
| 466 | 
            +
                    name_audio, text = sp_line[:2]
         | 
| 467 | 
            +
             | 
| 468 | 
            +
                    file_audio = os.path.join(path_project_wavs, name_audio + ".wav")
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                    if not os.path.isfile(file_audio):
         | 
| 471 | 
            +
                        error_files.append(file_audio)
         | 
| 472 | 
            +
                        continue
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                    duraction = get_audio_duration(file_audio)
         | 
| 475 | 
            +
                    if duraction < 2 and duraction > 15:
         | 
| 476 | 
            +
                        continue
         | 
| 477 | 
            +
                    if len(text) < 4:
         | 
| 478 | 
            +
                        continue
         | 
| 479 | 
            +
             | 
| 480 | 
            +
                    text = clear_text(text)
         | 
| 481 | 
            +
                    text = convert_char_to_pinyin([text], polyphone=True)[0]
         | 
| 482 | 
            +
             | 
| 483 | 
            +
                    audio_path_list.append(file_audio)
         | 
| 484 | 
            +
                    duration_list.append(duraction)
         | 
| 485 | 
            +
                    text_list.append(text)
         | 
| 486 | 
            +
             | 
| 487 | 
            +
                    result.append({"audio_path": file_audio, "text": text, "duration": duraction})
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                    lenght += duraction
         | 
| 490 | 
            +
             | 
| 491 | 
            +
                if duration_list == []:
         | 
| 492 | 
            +
                    error_files_text = "\n".join(error_files)
         | 
| 493 | 
            +
                    return f"Error: No audio files found in the specified path : \n{error_files_text}"
         | 
| 494 | 
            +
             | 
| 495 | 
            +
                min_second = round(min(duration_list), 2)
         | 
| 496 | 
            +
                max_second = round(max(duration_list), 2)
         | 
| 497 | 
            +
             | 
| 498 | 
            +
                with ArrowWriter(path=file_raw, writer_batch_size=1) as writer:
         | 
| 499 | 
            +
                    for line in progress.tqdm(result, total=len(result), desc="prepare data"):
         | 
| 500 | 
            +
                        writer.write(line)
         | 
| 501 | 
            +
             | 
| 502 | 
            +
                with open(file_duration, "w", encoding="utf-8") as f:
         | 
| 503 | 
            +
                    json.dump({"duration": duration_list}, f, ensure_ascii=False)
         | 
| 504 | 
            +
             | 
| 505 | 
            +
                file_vocab_finetune = "data/Emilia_ZH_EN_pinyin/vocab.txt"
         | 
| 506 | 
            +
                if not os.path.isfile(file_vocab_finetune):
         | 
| 507 | 
            +
                    return "Error: Vocabulary file 'Emilia_ZH_EN_pinyin' not found!"
         | 
| 508 | 
            +
                shutil.copy2(file_vocab_finetune, file_vocab)
         | 
| 509 | 
            +
             | 
| 510 | 
            +
                if error_files != []:
         | 
| 511 | 
            +
                    error_text = "error files\n" + "\n".join(error_files)
         | 
| 512 | 
            +
                else:
         | 
| 513 | 
            +
                    error_text = ""
         | 
| 514 | 
            +
             | 
| 515 | 
            +
                return f"prepare complete \nsamples : {len(text_list)}\ntime data : {format_seconds_to_hms(lenght)}\nmin sec : {min_second}\nmax sec : {max_second}\nfile_arrow : {file_raw}\n{error_text}"
         | 
| 516 | 
            +
             | 
| 517 | 
            +
             | 
| 518 | 
            +
            def check_user(value):
         | 
| 519 | 
            +
                return gr.update(visible=not value), gr.update(visible=value)
         | 
| 520 | 
            +
             | 
| 521 | 
            +
             | 
| 522 | 
            +
            def calculate_train(
         | 
| 523 | 
            +
                name_project,
         | 
| 524 | 
            +
                batch_size_type,
         | 
| 525 | 
            +
                max_samples,
         | 
| 526 | 
            +
                learning_rate,
         | 
| 527 | 
            +
                num_warmup_updates,
         | 
| 528 | 
            +
                save_per_updates,
         | 
| 529 | 
            +
                last_per_steps,
         | 
| 530 | 
            +
                finetune,
         | 
| 531 | 
            +
            ):
         | 
| 532 | 
            +
                name_project += "_pinyin"
         | 
| 533 | 
            +
                path_project = os.path.join(path_data, name_project)
         | 
| 534 | 
            +
                file_duraction = os.path.join(path_project, "duration.json")
         | 
| 535 | 
            +
             | 
| 536 | 
            +
                if not os.path.isfile(file_duraction):
         | 
| 537 | 
            +
                    return (
         | 
| 538 | 
            +
                        1000,
         | 
| 539 | 
            +
                        max_samples,
         | 
| 540 | 
            +
                        num_warmup_updates,
         | 
| 541 | 
            +
                        save_per_updates,
         | 
| 542 | 
            +
                        last_per_steps,
         | 
| 543 | 
            +
                        "project not found !",
         | 
| 544 | 
            +
                        learning_rate,
         | 
| 545 | 
            +
                    )
         | 
| 546 | 
            +
             | 
| 547 | 
            +
                with open(file_duraction, "r") as file:
         | 
| 548 | 
            +
                    data = json.load(file)
         | 
| 549 | 
            +
             | 
| 550 | 
            +
                duration_list = data["duration"]
         | 
| 551 | 
            +
             | 
| 552 | 
            +
                samples = len(duration_list)
         | 
| 553 | 
            +
             | 
| 554 | 
            +
                if torch.cuda.is_available():
         | 
| 555 | 
            +
                    gpu_properties = torch.cuda.get_device_properties(0)
         | 
| 556 | 
            +
                    total_memory = gpu_properties.total_memory / (1024**3)
         | 
| 557 | 
            +
                elif torch.backends.mps.is_available():
         | 
| 558 | 
            +
                    total_memory = psutil.virtual_memory().available / (1024**3)
         | 
| 559 | 
            +
             | 
| 560 | 
            +
                if batch_size_type == "frame":
         | 
| 561 | 
            +
                    batch = int(total_memory * 0.5)
         | 
| 562 | 
            +
                    batch = (lambda num: num + 1 if num % 2 != 0 else num)(batch)
         | 
| 563 | 
            +
                    batch_size_per_gpu = int(38400 / batch)
         | 
| 564 | 
            +
                else:
         | 
| 565 | 
            +
                    batch_size_per_gpu = int(total_memory / 8)
         | 
| 566 | 
            +
                    batch_size_per_gpu = (lambda num: num + 1 if num % 2 != 0 else num)(batch_size_per_gpu)
         | 
| 567 | 
            +
                    batch = batch_size_per_gpu
         | 
| 568 | 
            +
             | 
| 569 | 
            +
                if batch_size_per_gpu <= 0:
         | 
| 570 | 
            +
                    batch_size_per_gpu = 1
         | 
| 571 | 
            +
             | 
| 572 | 
            +
                if samples < 64:
         | 
| 573 | 
            +
                    max_samples = int(samples * 0.25)
         | 
| 574 | 
            +
                else:
         | 
| 575 | 
            +
                    max_samples = 64
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                num_warmup_updates = int(samples * 0.05)
         | 
| 578 | 
            +
                save_per_updates = int(samples * 0.10)
         | 
| 579 | 
            +
                last_per_steps = int(save_per_updates * 5)
         | 
| 580 | 
            +
             | 
| 581 | 
            +
                max_samples = (lambda num: num + 1 if num % 2 != 0 else num)(max_samples)
         | 
| 582 | 
            +
                num_warmup_updates = (lambda num: num + 1 if num % 2 != 0 else num)(num_warmup_updates)
         | 
| 583 | 
            +
                save_per_updates = (lambda num: num + 1 if num % 2 != 0 else num)(save_per_updates)
         | 
| 584 | 
            +
                last_per_steps = (lambda num: num + 1 if num % 2 != 0 else num)(last_per_steps)
         | 
| 585 | 
            +
             | 
| 586 | 
            +
                if finetune:
         | 
| 587 | 
            +
                    learning_rate = 1e-5
         | 
| 588 | 
            +
                else:
         | 
| 589 | 
            +
                    learning_rate = 7.5e-5
         | 
| 590 | 
            +
             | 
| 591 | 
            +
                return batch_size_per_gpu, max_samples, num_warmup_updates, save_per_updates, last_per_steps, samples, learning_rate
         | 
| 592 | 
            +
             | 
| 593 | 
            +
             | 
| 594 | 
            +
            def extract_and_save_ema_model(checkpoint_path: str, new_checkpoint_path: str) -> None:
         | 
| 595 | 
            +
                try:
         | 
| 596 | 
            +
                    checkpoint = torch.load(checkpoint_path)
         | 
| 597 | 
            +
                    print("Original Checkpoint Keys:", checkpoint.keys())
         | 
| 598 | 
            +
             | 
| 599 | 
            +
                    ema_model_state_dict = checkpoint.get("ema_model_state_dict", None)
         | 
| 600 | 
            +
             | 
| 601 | 
            +
                    if ema_model_state_dict is not None:
         | 
| 602 | 
            +
                        new_checkpoint = {"ema_model_state_dict": ema_model_state_dict}
         | 
| 603 | 
            +
                        torch.save(new_checkpoint, new_checkpoint_path)
         | 
| 604 | 
            +
                        return f"New checkpoint saved at: {new_checkpoint_path}"
         | 
| 605 | 
            +
                    else:
         | 
| 606 | 
            +
                        return "No 'ema_model_state_dict' found in the checkpoint."
         | 
| 607 | 
            +
             | 
| 608 | 
            +
                except Exception as e:
         | 
| 609 | 
            +
                    return f"An error occurred: {e}"
         | 
| 610 | 
            +
             | 
| 611 | 
            +
             | 
| 612 | 
            +
            def vocab_check(project_name):
         | 
| 613 | 
            +
                name_project = project_name + "_pinyin"
         | 
| 614 | 
            +
                path_project = os.path.join(path_data, name_project)
         | 
| 615 | 
            +
             | 
| 616 | 
            +
                file_metadata = os.path.join(path_project, "metadata.csv")
         | 
| 617 | 
            +
             | 
| 618 | 
            +
                file_vocab = "data/Emilia_ZH_EN_pinyin/vocab.txt"
         | 
| 619 | 
            +
                if not os.path.isfile(file_vocab):
         | 
| 620 | 
            +
                    return f"the file {file_vocab} not found !"
         | 
| 621 | 
            +
             | 
| 622 | 
            +
                with open(file_vocab, "r", encoding="utf-8") as f:
         | 
| 623 | 
            +
                    data = f.read()
         | 
| 624 | 
            +
             | 
| 625 | 
            +
                vocab = data.split("\n")
         | 
| 626 | 
            +
             | 
| 627 | 
            +
                if not os.path.isfile(file_metadata):
         | 
| 628 | 
            +
                    return f"the file {file_metadata} not found !"
         | 
| 629 | 
            +
             | 
| 630 | 
            +
                with open(file_metadata, "r", encoding="utf-8") as f:
         | 
| 631 | 
            +
                    data = f.read()
         | 
| 632 | 
            +
             | 
| 633 | 
            +
                miss_symbols = []
         | 
| 634 | 
            +
                miss_symbols_keep = {}
         | 
| 635 | 
            +
                for item in data.split("\n"):
         | 
| 636 | 
            +
                    sp = item.split("|")
         | 
| 637 | 
            +
                    if len(sp) != 2:
         | 
| 638 | 
            +
                        continue
         | 
| 639 | 
            +
             | 
| 640 | 
            +
                    text = sp[1].lower().strip()
         | 
| 641 | 
            +
             | 
| 642 | 
            +
                    for t in text:
         | 
| 643 | 
            +
                        if t not in vocab and t not in miss_symbols_keep:
         | 
| 644 | 
            +
                            miss_symbols.append(t)
         | 
| 645 | 
            +
                            miss_symbols_keep[t] = t
         | 
| 646 | 
            +
                if miss_symbols == []:
         | 
| 647 | 
            +
                    info = "You can train using your language !"
         | 
| 648 | 
            +
                else:
         | 
| 649 | 
            +
                    info = f"The following symbols are missing in your language : {len(miss_symbols)}\n\n" + "\n".join(miss_symbols)
         | 
| 650 | 
            +
             | 
| 651 | 
            +
                return info
         | 
| 652 | 
            +
             | 
| 653 | 
            +
             | 
| 654 | 
            +
            def get_random_sample_prepare(project_name):
         | 
| 655 | 
            +
                name_project = project_name + "_pinyin"
         | 
| 656 | 
            +
                path_project = os.path.join(path_data, name_project)
         | 
| 657 | 
            +
                file_arrow = os.path.join(path_project, "raw.arrow")
         | 
| 658 | 
            +
                if not os.path.isfile(file_arrow):
         | 
| 659 | 
            +
                    return "", None
         | 
| 660 | 
            +
                dataset = Dataset_.from_file(file_arrow)
         | 
| 661 | 
            +
                random_sample = dataset.shuffle(seed=random.randint(0, 1000)).select([0])
         | 
| 662 | 
            +
                text = "[" + " , ".join(["' " + t + " '" for t in random_sample["text"][0]]) + "]"
         | 
| 663 | 
            +
                audio_path = random_sample["audio_path"][0]
         | 
| 664 | 
            +
                return text, audio_path
         | 
| 665 | 
            +
             | 
| 666 | 
            +
             | 
| 667 | 
            +
            def get_random_sample_transcribe(project_name):
         | 
| 668 | 
            +
                name_project = project_name + "_pinyin"
         | 
| 669 | 
            +
                path_project = os.path.join(path_data, name_project)
         | 
| 670 | 
            +
                file_metadata = os.path.join(path_project, "metadata.csv")
         | 
| 671 | 
            +
                if not os.path.isfile(file_metadata):
         | 
| 672 | 
            +
                    return "", None
         | 
| 673 | 
            +
             | 
| 674 | 
            +
                data = ""
         | 
| 675 | 
            +
                with open(file_metadata, "r", encoding="utf-8") as f:
         | 
| 676 | 
            +
                    data = f.read()
         | 
| 677 | 
            +
             | 
| 678 | 
            +
                list_data = []
         | 
| 679 | 
            +
                for item in data.split("\n"):
         | 
| 680 | 
            +
                    sp = item.split("|")
         | 
| 681 | 
            +
                    if len(sp) != 2:
         | 
| 682 | 
            +
                        continue
         | 
| 683 | 
            +
                    list_data.append([os.path.join(path_project, "wavs", sp[0] + ".wav"), sp[1]])
         | 
| 684 | 
            +
             | 
| 685 | 
            +
                if list_data == []:
         | 
| 686 | 
            +
                    return "", None
         | 
| 687 | 
            +
             | 
| 688 | 
            +
                random_item = random.choice(list_data)
         | 
| 689 | 
            +
             | 
| 690 | 
            +
                return random_item[1], random_item[0]
         | 
| 691 | 
            +
             | 
| 692 | 
            +
             | 
| 693 | 
            +
            def get_random_sample_infer(project_name):
         | 
| 694 | 
            +
                text, audio = get_random_sample_transcribe(project_name)
         | 
| 695 | 
            +
                return (
         | 
| 696 | 
            +
                    text,
         | 
| 697 | 
            +
                    text,
         | 
| 698 | 
            +
                    audio,
         | 
| 699 | 
            +
                )
         | 
| 700 | 
            +
             | 
| 701 | 
            +
             | 
| 702 | 
            +
            def infer(file_checkpoint, exp_name, ref_text, ref_audio, gen_text, nfe_step):
         | 
| 703 | 
            +
                global last_checkpoint, last_device, tts_api
         | 
| 704 | 
            +
             | 
| 705 | 
            +
                if not os.path.isfile(file_checkpoint):
         | 
| 706 | 
            +
                    return None
         | 
| 707 | 
            +
             | 
| 708 | 
            +
                if training_process is not None:
         | 
| 709 | 
            +
                    device_test = "cpu"
         | 
| 710 | 
            +
                else:
         | 
| 711 | 
            +
                    device_test = None
         | 
| 712 | 
            +
             | 
| 713 | 
            +
                if last_checkpoint != file_checkpoint or last_device != device_test:
         | 
| 714 | 
            +
                    if last_checkpoint != file_checkpoint:
         | 
| 715 | 
            +
                        last_checkpoint = file_checkpoint
         | 
| 716 | 
            +
                    if last_device != device_test:
         | 
| 717 | 
            +
                        last_device = device_test
         | 
| 718 | 
            +
             | 
| 719 | 
            +
                    tts_api = F5TTS(model_type=exp_name, ckpt_file=file_checkpoint, device=device_test)
         | 
| 720 | 
            +
             | 
| 721 | 
            +
                    print("update", device_test, file_checkpoint)
         | 
| 722 | 
            +
             | 
| 723 | 
            +
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         | 
| 724 | 
            +
                    tts_api.infer(gen_text=gen_text, ref_text=ref_text, ref_file=ref_audio, nfe_step=nfe_step, file_wave=f.name)
         | 
| 725 | 
            +
                    return f.name
         | 
| 726 | 
            +
             | 
| 727 | 
            +
             | 
| 728 | 
            +
            with gr.Blocks() as app:
         | 
| 729 | 
            +
                with gr.Row():
         | 
| 730 | 
            +
                    project_name = gr.Textbox(label="project name", value="my_speak")
         | 
| 731 | 
            +
                    bt_create = gr.Button("create new project")
         | 
| 732 | 
            +
             | 
| 733 | 
            +
                bt_create.click(fn=create_data_project, inputs=[project_name])
         | 
| 734 | 
            +
             | 
| 735 | 
            +
                with gr.Tabs():
         | 
| 736 | 
            +
                    with gr.TabItem("transcribe Data"):
         | 
| 737 | 
            +
                        ch_manual = gr.Checkbox(label="user", value=False)
         | 
| 738 | 
            +
             | 
| 739 | 
            +
                        mark_info_transcribe = gr.Markdown(
         | 
| 740 | 
            +
                            """```plaintext    
         | 
| 741 | 
            +
                 Place your 'wavs' folder and 'metadata.csv' file in the {your_project_name}' directory. 
         | 
| 742 | 
            +
                             
         | 
| 743 | 
            +
                 my_speak/
         | 
| 744 | 
            +
                 β
         | 
| 745 | 
            +
                 βββ dataset/
         | 
| 746 | 
            +
                     βββ audio1.wav
         | 
| 747 | 
            +
                     βββ audio2.wav
         | 
| 748 | 
            +
                     ...
         | 
| 749 | 
            +
                 ```""",
         | 
| 750 | 
            +
                            visible=False,
         | 
| 751 | 
            +
                        )
         | 
| 752 | 
            +
             | 
| 753 | 
            +
                        audio_speaker = gr.File(label="voice", type="filepath", file_count="multiple")
         | 
| 754 | 
            +
                        txt_lang = gr.Text(label="Language", value="english")
         | 
| 755 | 
            +
                        bt_transcribe = bt_create = gr.Button("transcribe")
         | 
| 756 | 
            +
                        txt_info_transcribe = gr.Text(label="info", value="")
         | 
| 757 | 
            +
                        bt_transcribe.click(
         | 
| 758 | 
            +
                            fn=transcribe_all,
         | 
| 759 | 
            +
                            inputs=[project_name, audio_speaker, txt_lang, ch_manual],
         | 
| 760 | 
            +
                            outputs=[txt_info_transcribe],
         | 
| 761 | 
            +
                        )
         | 
| 762 | 
            +
                        ch_manual.change(fn=check_user, inputs=[ch_manual], outputs=[audio_speaker, mark_info_transcribe])
         | 
| 763 | 
            +
             | 
| 764 | 
            +
                        random_sample_transcribe = gr.Button("random sample")
         | 
| 765 | 
            +
             | 
| 766 | 
            +
                        with gr.Row():
         | 
| 767 | 
            +
                            random_text_transcribe = gr.Text(label="Text")
         | 
| 768 | 
            +
                            random_audio_transcribe = gr.Audio(label="Audio", type="filepath")
         | 
| 769 | 
            +
             | 
| 770 | 
            +
                        random_sample_transcribe.click(
         | 
| 771 | 
            +
                            fn=get_random_sample_transcribe,
         | 
| 772 | 
            +
                            inputs=[project_name],
         | 
| 773 | 
            +
                            outputs=[random_text_transcribe, random_audio_transcribe],
         | 
| 774 | 
            +
                        )
         | 
| 775 | 
            +
             | 
| 776 | 
            +
                    with gr.TabItem("prepare Data"):
         | 
| 777 | 
            +
                        gr.Markdown(
         | 
| 778 | 
            +
                            """```plaintext    
         | 
| 779 | 
            +
                 place all your wavs folder and your metadata.csv file in {your name project}                                 
         | 
| 780 | 
            +
                 my_speak/
         | 
| 781 | 
            +
                 β
         | 
| 782 | 
            +
                 βββ wavs/
         | 
| 783 | 
            +
                 β   βββ audio1.wav
         | 
| 784 | 
            +
                 β   βββ audio2.wav
         | 
| 785 | 
            +
                 |   ...
         | 
| 786 | 
            +
                 β
         | 
| 787 | 
            +
                 βββ metadata.csv
         | 
| 788 | 
            +
                  
         | 
| 789 | 
            +
                 file format metadata.csv
         | 
| 790 | 
            +
             | 
| 791 | 
            +
                 audio1|text1
         | 
| 792 | 
            +
                 audio2|text1
         | 
| 793 | 
            +
                 ...
         | 
| 794 | 
            +
             | 
| 795 | 
            +
                 ```"""
         | 
| 796 | 
            +
                        )
         | 
| 797 | 
            +
             | 
| 798 | 
            +
                        bt_prepare = bt_create = gr.Button("prepare")
         | 
| 799 | 
            +
                        txt_info_prepare = gr.Text(label="info", value="")
         | 
| 800 | 
            +
                        bt_prepare.click(fn=create_metadata, inputs=[project_name], outputs=[txt_info_prepare])
         | 
| 801 | 
            +
             | 
| 802 | 
            +
                        random_sample_prepare = gr.Button("random sample")
         | 
| 803 | 
            +
             | 
| 804 | 
            +
                        with gr.Row():
         | 
| 805 | 
            +
                            random_text_prepare = gr.Text(label="Pinyin")
         | 
| 806 | 
            +
                            random_audio_prepare = gr.Audio(label="Audio", type="filepath")
         | 
| 807 | 
            +
             | 
| 808 | 
            +
                        random_sample_prepare.click(
         | 
| 809 | 
            +
                            fn=get_random_sample_prepare, inputs=[project_name], outputs=[random_text_prepare, random_audio_prepare]
         | 
| 810 | 
            +
                        )
         | 
| 811 | 
            +
             | 
| 812 | 
            +
                    with gr.TabItem("train Data"):
         | 
| 813 | 
            +
                        with gr.Row():
         | 
| 814 | 
            +
                            bt_calculate = bt_create = gr.Button("Auto Settings")
         | 
| 815 | 
            +
                            ch_finetune = bt_create = gr.Checkbox(label="finetune", value=True)
         | 
| 816 | 
            +
                            lb_samples = gr.Label(label="samples")
         | 
| 817 | 
            +
                            batch_size_type = gr.Radio(label="Batch Size Type", choices=["frame", "sample"], value="frame")
         | 
| 818 | 
            +
             | 
| 819 | 
            +
                        with gr.Row():
         | 
| 820 | 
            +
                            exp_name = gr.Radio(label="Model", choices=["F5TTS_Base", "E2TTS_Base"], value="F5TTS_Base")
         | 
| 821 | 
            +
                            learning_rate = gr.Number(label="Learning Rate", value=1e-5, step=1e-5)
         | 
| 822 | 
            +
             | 
| 823 | 
            +
                        with gr.Row():
         | 
| 824 | 
            +
                            batch_size_per_gpu = gr.Number(label="Batch Size per GPU", value=1000)
         | 
| 825 | 
            +
                            max_samples = gr.Number(label="Max Samples", value=64)
         | 
| 826 | 
            +
             | 
| 827 | 
            +
                        with gr.Row():
         | 
| 828 | 
            +
                            grad_accumulation_steps = gr.Number(label="Gradient Accumulation Steps", value=1)
         | 
| 829 | 
            +
                            max_grad_norm = gr.Number(label="Max Gradient Norm", value=1.0)
         | 
| 830 | 
            +
             | 
| 831 | 
            +
                        with gr.Row():
         | 
| 832 | 
            +
                            epochs = gr.Number(label="Epochs", value=10)
         | 
| 833 | 
            +
                            num_warmup_updates = gr.Number(label="Warmup Updates", value=5)
         | 
| 834 | 
            +
             | 
| 835 | 
            +
                        with gr.Row():
         | 
| 836 | 
            +
                            save_per_updates = gr.Number(label="Save per Updates", value=10)
         | 
| 837 | 
            +
                            last_per_steps = gr.Number(label="Last per Steps", value=50)
         | 
| 838 | 
            +
             | 
| 839 | 
            +
                        with gr.Row():
         | 
| 840 | 
            +
                            start_button = gr.Button("Start Training")
         | 
| 841 | 
            +
                            stop_button = gr.Button("Stop Training", interactive=False)
         | 
| 842 | 
            +
             | 
| 843 | 
            +
                        txt_info_train = gr.Text(label="info", value="")
         | 
| 844 | 
            +
                        start_button.click(
         | 
| 845 | 
            +
                            fn=start_training,
         | 
| 846 | 
            +
                            inputs=[
         | 
| 847 | 
            +
                                project_name,
         | 
| 848 | 
            +
                                exp_name,
         | 
| 849 | 
            +
                                learning_rate,
         | 
| 850 | 
            +
                                batch_size_per_gpu,
         | 
| 851 | 
            +
                                batch_size_type,
         | 
| 852 | 
            +
                                max_samples,
         | 
| 853 | 
            +
                                grad_accumulation_steps,
         | 
| 854 | 
            +
                                max_grad_norm,
         | 
| 855 | 
            +
                                epochs,
         | 
| 856 | 
            +
                                num_warmup_updates,
         | 
| 857 | 
            +
                                save_per_updates,
         | 
| 858 | 
            +
                                last_per_steps,
         | 
| 859 | 
            +
                                ch_finetune,
         | 
| 860 | 
            +
                            ],
         | 
| 861 | 
            +
                            outputs=[txt_info_train, start_button, stop_button],
         | 
| 862 | 
            +
                        )
         | 
| 863 | 
            +
                        stop_button.click(fn=stop_training, outputs=[txt_info_train, start_button, stop_button])
         | 
| 864 | 
            +
                        bt_calculate.click(
         | 
| 865 | 
            +
                            fn=calculate_train,
         | 
| 866 | 
            +
                            inputs=[
         | 
| 867 | 
            +
                                project_name,
         | 
| 868 | 
            +
                                batch_size_type,
         | 
| 869 | 
            +
                                max_samples,
         | 
| 870 | 
            +
                                learning_rate,
         | 
| 871 | 
            +
                                num_warmup_updates,
         | 
| 872 | 
            +
                                save_per_updates,
         | 
| 873 | 
            +
                                last_per_steps,
         | 
| 874 | 
            +
                                ch_finetune,
         | 
| 875 | 
            +
                            ],
         | 
| 876 | 
            +
                            outputs=[
         | 
| 877 | 
            +
                                batch_size_per_gpu,
         | 
| 878 | 
            +
                                max_samples,
         | 
| 879 | 
            +
                                num_warmup_updates,
         | 
| 880 | 
            +
                                save_per_updates,
         | 
| 881 | 
            +
                                last_per_steps,
         | 
| 882 | 
            +
                                lb_samples,
         | 
| 883 | 
            +
                                learning_rate,
         | 
| 884 | 
            +
                            ],
         | 
| 885 | 
            +
                        )
         | 
| 886 | 
            +
             | 
| 887 | 
            +
                    with gr.TabItem("reduse checkpoint"):
         | 
| 888 | 
            +
                        txt_path_checkpoint = gr.Text(label="path checkpoint :")
         | 
| 889 | 
            +
                        txt_path_checkpoint_small = gr.Text(label="path output :")
         | 
| 890 | 
            +
                        txt_info_reduse = gr.Text(label="info", value="")
         | 
| 891 | 
            +
                        reduse_button = gr.Button("reduse")
         | 
| 892 | 
            +
                        reduse_button.click(
         | 
| 893 | 
            +
                            fn=extract_and_save_ema_model,
         | 
| 894 | 
            +
                            inputs=[txt_path_checkpoint, txt_path_checkpoint_small],
         | 
| 895 | 
            +
                            outputs=[txt_info_reduse],
         | 
| 896 | 
            +
                        )
         | 
| 897 | 
            +
             | 
| 898 | 
            +
                    with gr.TabItem("vocab check experiment"):
         | 
| 899 | 
            +
                        check_button = gr.Button("check vocab")
         | 
| 900 | 
            +
                        txt_info_check = gr.Text(label="info", value="")
         | 
| 901 | 
            +
                        check_button.click(fn=vocab_check, inputs=[project_name], outputs=[txt_info_check])
         | 
| 902 | 
            +
             | 
| 903 | 
            +
                    with gr.TabItem("test model"):
         | 
| 904 | 
            +
                        exp_name = gr.Radio(label="Model", choices=["F5-TTS", "E2-TTS"], value="F5-TTS")
         | 
| 905 | 
            +
                        nfe_step = gr.Number(label="n_step", value=32)
         | 
| 906 | 
            +
                        file_checkpoint_pt = gr.Textbox(label="Checkpoint", value="")
         | 
| 907 | 
            +
             | 
| 908 | 
            +
                        random_sample_infer = gr.Button("random sample")
         | 
| 909 | 
            +
             | 
| 910 | 
            +
                        ref_text = gr.Textbox(label="ref text")
         | 
| 911 | 
            +
                        ref_audio = gr.Audio(label="audio ref", type="filepath")
         | 
| 912 | 
            +
                        gen_text = gr.Textbox(label="gen text")
         | 
| 913 | 
            +
                        random_sample_infer.click(
         | 
| 914 | 
            +
                            fn=get_random_sample_infer, inputs=[project_name], outputs=[ref_text, gen_text, ref_audio]
         | 
| 915 | 
            +
                        )
         | 
| 916 | 
            +
                        check_button_infer = gr.Button("infer")
         | 
| 917 | 
            +
                        gen_audio = gr.Audio(label="audio gen", type="filepath")
         | 
| 918 | 
            +
             | 
| 919 | 
            +
                        check_button_infer.click(
         | 
| 920 | 
            +
                            fn=infer,
         | 
| 921 | 
            +
                            inputs=[file_checkpoint_pt, exp_name, ref_text, ref_audio, gen_text, nfe_step],
         | 
| 922 | 
            +
                            outputs=[gen_audio],
         | 
| 923 | 
            +
                        )
         | 
| 924 | 
            +
             | 
| 925 | 
            +
             | 
| 926 | 
            +
            @click.command()
         | 
| 927 | 
            +
            @click.option("--port", "-p", default=None, type=int, help="Port to run the app on")
         | 
| 928 | 
            +
            @click.option("--host", "-H", default=None, help="Host to run the app on")
         | 
| 929 | 
            +
            @click.option(
         | 
| 930 | 
            +
                "--share",
         | 
| 931 | 
            +
                "-s",
         | 
| 932 | 
            +
                default=False,
         | 
| 933 | 
            +
                is_flag=True,
         | 
| 934 | 
            +
                help="Share the app via Gradio share link",
         | 
| 935 | 
            +
            )
         | 
| 936 | 
            +
            @click.option("--api", "-a", default=True, is_flag=True, help="Allow API access")
         | 
| 937 | 
            +
            def main(port, host, share, api):
         | 
| 938 | 
            +
                global app
         | 
| 939 | 
            +
                print("Starting app...")
         | 
| 940 | 
            +
                app.queue(api_open=api).launch(server_name=host, server_port=port, share=share, show_api=api)
         | 
| 941 | 
            +
             | 
| 942 | 
            +
             | 
| 943 | 
            +
            if __name__ == "__main__":
         | 
| 944 | 
            +
                main()
         | 
    	
        src/f5_tts/{train.py β train/train.py}
    RENAMED
    
    | 
            File without changes
         | 

