Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		zetavg
		
	commited on
		
		
					make the training process async
Browse files- README.md +2 -2
 - app.py +4 -0
 - config.yaml.sample +2 -0
 - llama_lora/config.py +7 -1
 - llama_lora/globals.py +19 -0
 - llama_lora/models.py +8 -0
 - llama_lora/ui/finetune/finetune_ui.py +23 -12
 - llama_lora/ui/finetune/script.js +26 -11
 - llama_lora/ui/finetune/style.css +114 -4
 - llama_lora/ui/finetune/training.py +275 -212
 - llama_lora/ui/inference_ui.py +2 -2
 - llama_lora/ui/main_page.py +55 -1
 - llama_lora/ui/tokenizer_ui.py +1 -1
 - llama_lora/ui/trainer_callback.py +104 -0
 - llama_lora/utils/eta_predictor.py +54 -0
 
    	
        README.md
    CHANGED
    
    | 
         @@ -70,7 +70,7 @@ setup: | 
     | 
|
| 70 | 
         
             
            # Start the app.
         
     | 
| 71 | 
         
             
            run: |
         
     | 
| 72 | 
         
             
              echo 'Starting...'
         
     | 
| 73 | 
         
            -
              python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key="$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --base_model=decapoda-research/llama-7b-hf --base_model_choices='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j,databricks/dolly-v2-7b --share
         
     | 
| 74 | 
         
             
            ```
         
     | 
| 75 | 
         | 
| 76 | 
         
             
            Then launch a cluster to run the task:
         
     | 
| 
         @@ -100,7 +100,7 @@ When you are done, run `sky stop <cluster_name>` to stop the cluster. To termina 
     | 
|
| 100 | 
         | 
| 101 | 
         
             
            ```bash
         
     | 
| 102 | 
         
             
            pip install -r requirements.lock.txt
         
     | 
| 103 | 
         
            -
            python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --share
         
     | 
| 104 | 
         
             
            ```
         
     | 
| 105 | 
         | 
| 106 | 
         
             
            You will see the local and public URLs of the app in the terminal. Open the URL in your browser to use the app.
         
     | 
| 
         | 
|
| 70 | 
         
             
            # Start the app.
         
     | 
| 71 | 
         
             
            run: |
         
     | 
| 72 | 
         
             
              echo 'Starting...'
         
     | 
| 73 | 
         
            +
              python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key="$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --timezone='Atlantic/Reykjavik' --base_model=decapoda-research/llama-7b-hf --base_model_choices='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j,databricks/dolly-v2-7b --share
         
     | 
| 74 | 
         
             
            ```
         
     | 
| 75 | 
         | 
| 76 | 
         
             
            Then launch a cluster to run the task:
         
     | 
| 
         | 
|
| 100 | 
         | 
| 101 | 
         
             
            ```bash
         
     | 
| 102 | 
         
             
            pip install -r requirements.lock.txt
         
     | 
| 103 | 
         
            +
            python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --timezone='Atlantic/Reykjavik' --share
         
     | 
| 104 | 
         
             
            ```
         
     | 
| 105 | 
         | 
| 106 | 
         
             
            You will see the local and public URLs of the app in the terminal. Open the URL in your browser to use the app.
         
     | 
    	
        app.py
    CHANGED
    
    | 
         @@ -28,6 +28,7 @@ def main( 
     | 
|
| 28 | 
         
             
                ui_dev_mode: Union[bool, None] = None,
         
     | 
| 29 | 
         
             
                wandb_api_key: Union[str, None] = None,
         
     | 
| 30 | 
         
             
                wandb_project: Union[str, None] = None,
         
     | 
| 
         | 
|
| 31 | 
         
             
            ):
         
     | 
| 32 | 
         
             
                '''
         
     | 
| 33 | 
         
             
                Start the LLaMA-LoRA Tuner UI.
         
     | 
| 
         @@ -76,6 +77,9 @@ def main( 
     | 
|
| 76 | 
         
             
                if wandb_project is not None:
         
     | 
| 77 | 
         
             
                    Config.default_wandb_project = wandb_project
         
     | 
| 78 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 79 | 
         
             
                if ui_dev_mode is not None:
         
     | 
| 80 | 
         
             
                    Config.ui_dev_mode = ui_dev_mode
         
     | 
| 81 | 
         | 
| 
         | 
|
| 28 | 
         
             
                ui_dev_mode: Union[bool, None] = None,
         
     | 
| 29 | 
         
             
                wandb_api_key: Union[str, None] = None,
         
     | 
| 30 | 
         
             
                wandb_project: Union[str, None] = None,
         
     | 
| 31 | 
         
            +
                timezone: Union[str, None] = None,
         
     | 
| 32 | 
         
             
            ):
         
     | 
| 33 | 
         
             
                '''
         
     | 
| 34 | 
         
             
                Start the LLaMA-LoRA Tuner UI.
         
     | 
| 
         | 
|
| 77 | 
         
             
                if wandb_project is not None:
         
     | 
| 78 | 
         
             
                    Config.default_wandb_project = wandb_project
         
     | 
| 79 | 
         | 
| 80 | 
         
            +
                if timezone is not None:
         
     | 
| 81 | 
         
            +
                    Config.timezone = timezone
         
     | 
| 82 | 
         
            +
             
     | 
| 83 | 
         
             
                if ui_dev_mode is not None:
         
     | 
| 84 | 
         
             
                    Config.ui_dev_mode = ui_dev_mode
         
     | 
| 85 | 
         | 
    	
        config.yaml.sample
    CHANGED
    
    | 
         @@ -9,6 +9,8 @@ base_model_choices: 
     | 
|
| 9 | 
         
             
            load_8bit: false
         
     | 
| 10 | 
         
             
            trust_remote_code: false
         
     | 
| 11 | 
         | 
| 
         | 
|
| 
         | 
|
| 12 | 
         
             
            # UI Customization
         
     | 
| 13 | 
         
             
            # ui_title: LLM Tuner
         
     | 
| 14 | 
         
             
            # ui_emoji: 🦙🎛️
         
     | 
| 
         | 
|
| 9 | 
         
             
            load_8bit: false
         
     | 
| 10 | 
         
             
            trust_remote_code: false
         
     | 
| 11 | 
         | 
| 12 | 
         
            +
            # timezone: Atlantic/Reykjavik
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
             
            # UI Customization
         
     | 
| 15 | 
         
             
            # ui_title: LLM Tuner
         
     | 
| 16 | 
         
             
            # ui_emoji: 🦙🎛️
         
     | 
    	
        llama_lora/config.py
    CHANGED
    
    | 
         @@ -1,5 +1,6 @@ 
     | 
|
| 1 | 
         
             
            import os
         
     | 
| 2 | 
         
            -
             
     | 
| 
         | 
|
| 3 | 
         | 
| 4 | 
         | 
| 5 | 
         
             
            class Config:
         
     | 
| 
         @@ -15,6 +16,8 @@ class Config: 
     | 
|
| 15 | 
         | 
| 16 | 
         
             
                trust_remote_code: bool = False
         
     | 
| 17 | 
         | 
| 
         | 
|
| 
         | 
|
| 18 | 
         
             
                # WandB
         
     | 
| 19 | 
         
             
                enable_wandb: Union[bool, None] = False
         
     | 
| 20 | 
         
             
                wandb_api_key: Union[str, None] = None
         
     | 
| 
         @@ -37,6 +40,9 @@ def process_config(): 
     | 
|
| 37 | 
         
             
                    base_model_choices = [name.strip() for name in base_model_choices]
         
     | 
| 38 | 
         
             
                    Config.base_model_choices = base_model_choices
         
     | 
| 39 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 40 | 
         
             
                if Config.default_base_model_name not in Config.base_model_choices:
         
     | 
| 41 | 
         
             
                    Config.base_model_choices = [Config.default_base_model_name] + Config.base_model_choices
         
     | 
| 42 | 
         | 
| 
         | 
|
| 1 | 
         
             
            import os
         
     | 
| 2 | 
         
            +
            import pytz
         
     | 
| 3 | 
         
            +
            from typing import List, Union, Any
         
     | 
| 4 | 
         | 
| 5 | 
         | 
| 6 | 
         
             
            class Config:
         
     | 
| 
         | 
|
| 16 | 
         | 
| 17 | 
         
             
                trust_remote_code: bool = False
         
     | 
| 18 | 
         | 
| 19 | 
         
            +
                timezone: Any = pytz.UTC
         
     | 
| 20 | 
         
            +
             
     | 
| 21 | 
         
             
                # WandB
         
     | 
| 22 | 
         
             
                enable_wandb: Union[bool, None] = False
         
     | 
| 23 | 
         
             
                wandb_api_key: Union[str, None] = None
         
     | 
| 
         | 
|
| 40 | 
         
             
                    base_model_choices = [name.strip() for name in base_model_choices]
         
     | 
| 41 | 
         
             
                    Config.base_model_choices = base_model_choices
         
     | 
| 42 | 
         | 
| 43 | 
         
            +
                if isinstance(Config.timezone, str):
         
     | 
| 44 | 
         
            +
                    Config.timezone = pytz.timezone(Config.timezone)
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
             
                if Config.default_base_model_name not in Config.base_model_choices:
         
     | 
| 47 | 
         
             
                    Config.base_model_choices = [Config.default_base_model_name] + Config.base_model_choices
         
     | 
| 48 | 
         | 
    	
        llama_lora/globals.py
    CHANGED
    
    | 
         @@ -12,6 +12,7 @@ import nvidia_smi 
     | 
|
| 12 | 
         
             
            from .dynamic_import import dynamic_import
         
     | 
| 13 | 
         
             
            from .config import Config
         
     | 
| 14 | 
         
             
            from .utils.lru_cache import LRUCache
         
     | 
| 
         | 
|
| 15 | 
         | 
| 16 | 
         | 
| 17 | 
         
             
            class Global:
         
     | 
| 
         @@ -31,6 +32,24 @@ class Global: 
     | 
|
| 31 | 
         
             
                # Training Control
         
     | 
| 32 | 
         
             
                should_stop_training: bool = False
         
     | 
| 33 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 34 | 
         
             
                # Generation Control
         
     | 
| 35 | 
         
             
                should_stop_generating: bool = False
         
     | 
| 36 | 
         
             
                generation_force_stopped_at: Union[float, None] = None
         
     | 
| 
         | 
|
| 12 | 
         
             
            from .dynamic_import import dynamic_import
         
     | 
| 13 | 
         
             
            from .config import Config
         
     | 
| 14 | 
         
             
            from .utils.lru_cache import LRUCache
         
     | 
| 15 | 
         
            +
            from .utils.eta_predictor import ETAPredictor
         
     | 
| 16 | 
         | 
| 17 | 
         | 
| 18 | 
         
             
            class Global:
         
     | 
| 
         | 
|
| 32 | 
         
             
                # Training Control
         
     | 
| 33 | 
         
             
                should_stop_training: bool = False
         
     | 
| 34 | 
         | 
| 35 | 
         
            +
                # Training Status
         
     | 
| 36 | 
         
            +
                is_train_starting: bool = False
         
     | 
| 37 | 
         
            +
                is_training: bool = False
         
     | 
| 38 | 
         
            +
                train_started_at: float = 0.0
         
     | 
| 39 | 
         
            +
                training_error_message: Union[str, None] = None
         
     | 
| 40 | 
         
            +
                training_error_detail: Union[str, None] = None
         
     | 
| 41 | 
         
            +
                training_total_epochs: int = 0
         
     | 
| 42 | 
         
            +
                training_current_epoch: float = 0.0
         
     | 
| 43 | 
         
            +
                training_total_steps: int = 0
         
     | 
| 44 | 
         
            +
                training_current_step: int = 0
         
     | 
| 45 | 
         
            +
                training_progress: float = 0.0
         
     | 
| 46 | 
         
            +
                training_log_history: List[Any] = []
         
     | 
| 47 | 
         
            +
                training_status_text: str = ""
         
     | 
| 48 | 
         
            +
                training_eta_predictor = ETAPredictor()
         
     | 
| 49 | 
         
            +
                training_eta: Union[int, None] = None
         
     | 
| 50 | 
         
            +
                train_output: Union[None, Any] = None
         
     | 
| 51 | 
         
            +
                train_output_str: Union[None, str] = None
         
     | 
| 52 | 
         
            +
             
     | 
| 53 | 
         
             
                # Generation Control
         
     | 
| 54 | 
         
             
                should_stop_generating: bool = False
         
     | 
| 55 | 
         
             
                generation_force_stopped_at: Union[float, None] = None
         
     | 
    	
        llama_lora/models.py
    CHANGED
    
    | 
         @@ -26,6 +26,8 @@ def get_peft_model_class(): 
     | 
|
| 26 | 
         
             
            def get_new_base_model(base_model_name):
         
     | 
| 27 | 
         
             
                if Config.ui_dev_mode:
         
     | 
| 28 | 
         
             
                    return
         
     | 
| 
         | 
|
| 
         | 
|
| 29 | 
         | 
| 30 | 
         
             
                if Global.new_base_model_that_is_ready_to_be_used:
         
     | 
| 31 | 
         
             
                    if Global.name_of_new_base_model_that_is_ready_to_be_used == base_model_name:
         
     | 
| 
         @@ -121,6 +123,9 @@ def get_tokenizer(base_model_name): 
     | 
|
| 121 | 
         
             
                if Config.ui_dev_mode:
         
     | 
| 122 | 
         
             
                    return
         
     | 
| 123 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 124 | 
         
             
                loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
         
     | 
| 125 | 
         
             
                if loaded_tokenizer:
         
     | 
| 126 | 
         
             
                    return loaded_tokenizer
         
     | 
| 
         @@ -150,6 +155,9 @@ def get_model( 
     | 
|
| 150 | 
         
             
                if Config.ui_dev_mode:
         
     | 
| 151 | 
         
             
                    return
         
     | 
| 152 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 153 | 
         
             
                if peft_model_name == "None":
         
     | 
| 154 | 
         
             
                    peft_model_name = None
         
     | 
| 155 | 
         | 
| 
         | 
|
| 26 | 
         
             
            def get_new_base_model(base_model_name):
         
     | 
| 27 | 
         
             
                if Config.ui_dev_mode:
         
     | 
| 28 | 
         
             
                    return
         
     | 
| 29 | 
         
            +
                if Global.is_train_starting or Global.is_training:
         
     | 
| 30 | 
         
            +
                    raise Exception("Cannot load new base model while training.")
         
     | 
| 31 | 
         | 
| 32 | 
         
             
                if Global.new_base_model_that_is_ready_to_be_used:
         
     | 
| 33 | 
         
             
                    if Global.name_of_new_base_model_that_is_ready_to_be_used == base_model_name:
         
     | 
| 
         | 
|
| 123 | 
         
             
                if Config.ui_dev_mode:
         
     | 
| 124 | 
         
             
                    return
         
     | 
| 125 | 
         | 
| 126 | 
         
            +
                if Global.is_train_starting or Global.is_training:
         
     | 
| 127 | 
         
            +
                    raise Exception("Cannot load new base model while training.")
         
     | 
| 128 | 
         
            +
             
     | 
| 129 | 
         
             
                loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
         
     | 
| 130 | 
         
             
                if loaded_tokenizer:
         
     | 
| 131 | 
         
             
                    return loaded_tokenizer
         
     | 
| 
         | 
|
| 155 | 
         
             
                if Config.ui_dev_mode:
         
     | 
| 156 | 
         
             
                    return
         
     | 
| 157 | 
         | 
| 158 | 
         
            +
                if Global.is_train_starting or Global.is_training:
         
     | 
| 159 | 
         
            +
                    raise Exception("Cannot load new base model while training.")
         
     | 
| 160 | 
         
            +
             
     | 
| 161 | 
         
             
                if peft_model_name == "None":
         
     | 
| 162 | 
         
             
                    peft_model_name = None
         
     | 
| 163 | 
         | 
    	
        llama_lora/ui/finetune/finetune_ui.py
    CHANGED
    
    | 
         @@ -27,7 +27,8 @@ from .previewing import ( 
     | 
|
| 27 | 
         
             
                refresh_dataset_items_count,
         
     | 
| 28 | 
         
             
            )
         
     | 
| 29 | 
         
             
            from .training import (
         
     | 
| 30 | 
         
            -
                do_train
         
     | 
| 
         | 
|
| 31 | 
         
             
            )
         
     | 
| 32 | 
         | 
| 33 | 
         
             
            register_css_style('finetune', relative_read_file(__file__, "style.css"))
         
     | 
| 
         @@ -770,19 +771,22 @@ def finetune_ui(): 
     | 
|
| 770 | 
         
             
                        )
         
     | 
| 771 | 
         
             
                    )
         
     | 
| 772 | 
         | 
| 773 | 
         
            -
                     
     | 
| 774 | 
         
             
                        "Training results will be shown here.",
         
     | 
| 775 | 
         
             
                        label="Train Output",
         
     | 
| 776 | 
         
             
                        elem_id="finetune_training_status")
         
     | 
| 777 | 
         | 
| 778 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 779 | 
         
             
                        fn=do_train,
         
     | 
| 780 | 
         
             
                        inputs=(dataset_inputs + finetune_args + [
         
     | 
| 781 | 
         
             
                            model_name,
         
     | 
| 782 | 
         
             
                            continue_from_model,
         
     | 
| 783 | 
         
             
                            continue_from_checkpoint,
         
     | 
| 784 | 
         
             
                        ]),
         
     | 
| 785 | 
         
            -
                        outputs= 
     | 
| 786 | 
         
             
                    )
         
     | 
| 787 | 
         | 
| 788 | 
         
             
                    # controlled by JS, shows the confirm_abort_button
         
     | 
| 
         @@ -790,13 +794,20 @@ def finetune_ui(): 
     | 
|
| 790 | 
         
             
                    confirm_abort_button.click(
         
     | 
| 791 | 
         
             
                        fn=do_abort_training,
         
     | 
| 792 | 
         
             
                        inputs=None, outputs=None,
         
     | 
| 793 | 
         
            -
                        cancels=[ 
     | 
| 794 | 
         
            -
             
     | 
| 795 | 
         
            -
                    stop_timeoutable_btn = gr.Button(
         
     | 
| 796 | 
         
            -
                        "stop not-responding elements",
         
     | 
| 797 | 
         
            -
                        elem_id="inference_stop_timeoutable_btn",
         
     | 
| 798 | 
         
            -
                        elem_classes="foot_stop_timeoutable_btn")
         
     | 
| 799 | 
         
            -
                    stop_timeoutable_btn.click(
         
     | 
| 800 | 
         
            -
                        fn=None, inputs=None, outputs=None, cancels=things_that_might_timeout)
         
     | 
| 801 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 802 | 
         
             
                finetune_ui_blocks.load(_js=relative_read_file(__file__, "script.js"))
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 27 | 
         
             
                refresh_dataset_items_count,
         
     | 
| 28 | 
         
             
            )
         
     | 
| 29 | 
         
             
            from .training import (
         
     | 
| 30 | 
         
            +
                do_train,
         
     | 
| 31 | 
         
            +
                render_training_status
         
     | 
| 32 | 
         
             
            )
         
     | 
| 33 | 
         | 
| 34 | 
         
             
            register_css_style('finetune', relative_read_file(__file__, "style.css"))
         
     | 
| 
         | 
|
| 771 | 
         
             
                        )
         
     | 
| 772 | 
         
             
                    )
         
     | 
| 773 | 
         | 
| 774 | 
         
            +
                    train_status = gr.HTML(
         
     | 
| 775 | 
         
             
                        "Training results will be shown here.",
         
     | 
| 776 | 
         
             
                        label="Train Output",
         
     | 
| 777 | 
         
             
                        elem_id="finetune_training_status")
         
     | 
| 778 | 
         | 
| 779 | 
         
            +
                    training_indicator = gr.HTML(
         
     | 
| 780 | 
         
            +
                        "training_indicator", visible=False, elem_id="finetune_training_indicator")
         
     | 
| 781 | 
         
            +
             
     | 
| 782 | 
         
            +
                    train_start = train_btn.click(
         
     | 
| 783 | 
         
             
                        fn=do_train,
         
     | 
| 784 | 
         
             
                        inputs=(dataset_inputs + finetune_args + [
         
     | 
| 785 | 
         
             
                            model_name,
         
     | 
| 786 | 
         
             
                            continue_from_model,
         
     | 
| 787 | 
         
             
                            continue_from_checkpoint,
         
     | 
| 788 | 
         
             
                        ]),
         
     | 
| 789 | 
         
            +
                        outputs=[train_status, training_indicator]
         
     | 
| 790 | 
         
             
                    )
         
     | 
| 791 | 
         | 
| 792 | 
         
             
                    # controlled by JS, shows the confirm_abort_button
         
     | 
| 
         | 
|
| 794 | 
         
             
                    confirm_abort_button.click(
         
     | 
| 795 | 
         
             
                        fn=do_abort_training,
         
     | 
| 796 | 
         
             
                        inputs=None, outputs=None,
         
     | 
| 797 | 
         
            +
                        cancels=[train_start])
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 798 | 
         | 
| 799 | 
         
            +
                training_status_updates = finetune_ui_blocks.load(
         
     | 
| 800 | 
         
            +
                    fn=render_training_status,
         
     | 
| 801 | 
         
            +
                    inputs=None,
         
     | 
| 802 | 
         
            +
                    outputs=[train_status, training_indicator],
         
     | 
| 803 | 
         
            +
                    every=0.1
         
     | 
| 804 | 
         
            +
                )
         
     | 
| 805 | 
         
             
                finetune_ui_blocks.load(_js=relative_read_file(__file__, "script.js"))
         
     | 
| 806 | 
         
            +
             
     | 
| 807 | 
         
            +
                # things_that_might_timeout.append(training_status_updates)
         
     | 
| 808 | 
         
            +
                stop_timeoutable_btn = gr.Button(
         
     | 
| 809 | 
         
            +
                    "stop not-responding elements",
         
     | 
| 810 | 
         
            +
                    elem_id="inference_stop_timeoutable_btn",
         
     | 
| 811 | 
         
            +
                    elem_classes="foot_stop_timeoutable_btn")
         
     | 
| 812 | 
         
            +
                stop_timeoutable_btn.click(
         
     | 
| 813 | 
         
            +
                    fn=None, inputs=None, outputs=None, cancels=things_that_might_timeout)
         
     | 
    	
        llama_lora/ui/finetune/script.js
    CHANGED
    
    | 
         @@ -130,10 +130,10 @@ function finetune_ui_blocks_js() { 
     | 
|
| 130 | 
         | 
| 131 | 
         
             
              // Show/hide start and stop button base on the state.
         
     | 
| 132 | 
         
             
              setTimeout(function () {
         
     | 
| 133 | 
         
            -
                // Make the '# 
     | 
| 134 | 
         
            -
                if (!document.querySelector('# 
     | 
| 135 | 
         
            -
             
     | 
| 136 | 
         
            -
                }
         
     | 
| 137 | 
         | 
| 138 | 
         
             
                setTimeout(function () {
         
     | 
| 139 | 
         
             
                  let resetStopButtonTimer;
         
     | 
| 
         @@ -156,11 +156,20 @@ function finetune_ui_blocks_js() { 
     | 
|
| 156 | 
         
             
                      document.getElementById('finetune_confirm_stop_btn').style.display =
         
     | 
| 157 | 
         
             
                        'block';
         
     | 
| 158 | 
         
             
                    });
         
     | 
| 159 | 
         
            -
                  const  
     | 
| 160 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 161 | 
         
             
                  );
         
     | 
| 162 | 
         
            -
                   
     | 
| 163 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 164 | 
         
             
                      if (resetStopButtonTimer) clearTimeout(resetStopButtonTimer);
         
     | 
| 165 | 
         
             
                      document.getElementById('finetune_start_btn').style.display = 'block';
         
     | 
| 166 | 
         
             
                      document.getElementById('finetune_stop_btn').style.display = 'none';
         
     | 
| 
         @@ -173,13 +182,19 @@ function finetune_ui_blocks_js() { 
     | 
|
| 173 | 
         
             
                        'none';
         
     | 
| 174 | 
         
             
                    }
         
     | 
| 175 | 
         
             
                  }
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 176 | 
         
             
                  new MutationObserver(function (mutationsList, observer) {
         
     | 
| 177 | 
         
            -
                     
     | 
| 178 | 
         
            -
                  }).observe( 
     | 
| 179 | 
         
             
                    attributes: true,
         
     | 
| 180 | 
         
             
                    attributeFilter: ['class'],
         
     | 
| 181 | 
         
             
                  });
         
     | 
| 182 | 
         
            -
                   
     | 
| 183 | 
         
             
                }, 500);
         
     | 
| 184 | 
         
             
              }, 0);
         
     | 
| 185 | 
         
             
            }
         
     | 
| 
         | 
|
| 130 | 
         | 
| 131 | 
         
             
              // Show/hide start and stop button base on the state.
         
     | 
| 132 | 
         
             
              setTimeout(function () {
         
     | 
| 133 | 
         
            +
                // Make the '#finetune_training_indicator > .wrap' element appear
         
     | 
| 134 | 
         
            +
                // if (!document.querySelector('#finetune_training_indicator > .wrap')) {
         
     | 
| 135 | 
         
            +
                //   document.getElementById('finetune_confirm_stop_btn').click();
         
     | 
| 136 | 
         
            +
                // }
         
     | 
| 137 | 
         | 
| 138 | 
         
             
                setTimeout(function () {
         
     | 
| 139 | 
         
             
                  let resetStopButtonTimer;
         
     | 
| 
         | 
|
| 156 | 
         
             
                      document.getElementById('finetune_confirm_stop_btn').style.display =
         
     | 
| 157 | 
         
             
                        'block';
         
     | 
| 158 | 
         
             
                    });
         
     | 
| 159 | 
         
            +
                  // const training_indicator_wrap_element = document.querySelector(
         
     | 
| 160 | 
         
            +
                  //   '#finetune_training_indicator > .wrap'
         
     | 
| 161 | 
         
            +
                  // );
         
     | 
| 162 | 
         
            +
                  const training_indicator_element = document.querySelector(
         
     | 
| 163 | 
         
            +
                    '#finetune_training_indicator'
         
     | 
| 164 | 
         
             
                  );
         
     | 
| 165 | 
         
            +
                  let isTraining = undefined;
         
     | 
| 166 | 
         
            +
                  function handle_training_indicator_change() {
         
     | 
| 167 | 
         
            +
                    // const wrapperHidden = Array.from(training_indicator_wrap_element.classList).includes('hide');
         
     | 
| 168 | 
         
            +
                    const hidden = Array.from(training_indicator_element.classList).includes('hidden');
         
     | 
| 169 | 
         
            +
                    const newIsTraining = !(/* wrapperHidden && */ hidden);
         
     | 
| 170 | 
         
            +
                    if (newIsTraining === isTraining) return;
         
     | 
| 171 | 
         
            +
                    isTraining = newIsTraining;
         
     | 
| 172 | 
         
            +
                    if (!isTraining) {
         
     | 
| 173 | 
         
             
                      if (resetStopButtonTimer) clearTimeout(resetStopButtonTimer);
         
     | 
| 174 | 
         
             
                      document.getElementById('finetune_start_btn').style.display = 'block';
         
     | 
| 175 | 
         
             
                      document.getElementById('finetune_stop_btn').style.display = 'none';
         
     | 
| 
         | 
|
| 182 | 
         
             
                        'none';
         
     | 
| 183 | 
         
             
                    }
         
     | 
| 184 | 
         
             
                  }
         
     | 
| 185 | 
         
            +
                  // new MutationObserver(function (mutationsList, observer) {
         
     | 
| 186 | 
         
            +
                  //   handle_training_indicator_change();
         
     | 
| 187 | 
         
            +
                  // }).observe(training_indicator_wrap_element, {
         
     | 
| 188 | 
         
            +
                  //   attributes: true,
         
     | 
| 189 | 
         
            +
                  //   attributeFilter: ['class'],
         
     | 
| 190 | 
         
            +
                  // });
         
     | 
| 191 | 
         
             
                  new MutationObserver(function (mutationsList, observer) {
         
     | 
| 192 | 
         
            +
                    handle_training_indicator_change();
         
     | 
| 193 | 
         
            +
                  }).observe(training_indicator_element, {
         
     | 
| 194 | 
         
             
                    attributes: true,
         
     | 
| 195 | 
         
             
                    attributeFilter: ['class'],
         
     | 
| 196 | 
         
             
                  });
         
     | 
| 197 | 
         
            +
                  handle_training_indicator_change();
         
     | 
| 198 | 
         
             
                }, 500);
         
     | 
| 199 | 
         
             
              }, 0);
         
     | 
| 200 | 
         
             
            }
         
     | 
    	
        llama_lora/ui/finetune/style.css
    CHANGED
    
    | 
         @@ -255,8 +255,118 @@ 
     | 
|
| 255 | 
         
             
                display: none;
         
     | 
| 256 | 
         
             
            }
         
     | 
| 257 | 
         | 
| 258 | 
         
            -
             
     | 
| 259 | 
         
            -
             
     | 
| 260 | 
         
            -
                 
     | 
| 261 | 
         
            -
                 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 262 | 
         
             
            }
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 255 | 
         
             
                display: none;
         
     | 
| 256 | 
         
             
            }
         
     | 
| 257 | 
         | 
| 258 | 
         
            +
            #finetune_training_status > .wrap {
         
     | 
| 259 | 
         
            +
                border: 0;
         
     | 
| 260 | 
         
            +
                background: transparent;
         
     | 
| 261 | 
         
            +
                pointer-events: none;
         
     | 
| 262 | 
         
            +
                top: 0;
         
     | 
| 263 | 
         
            +
                bottom: 0;
         
     | 
| 264 | 
         
            +
                left: 0;
         
     | 
| 265 | 
         
            +
                right: 0;
         
     | 
| 266 | 
         
            +
            }
         
     | 
| 267 | 
         
            +
            #finetune_training_status > .wrap .meta-text-center {
         
     | 
| 268 | 
         
            +
                transform: none !important;
         
     | 
| 269 | 
         
            +
            }
         
     | 
| 270 | 
         
            +
             
     | 
| 271 | 
         
            +
            #finetune_training_status .progress-block {
         
     | 
| 272 | 
         
            +
                min-height: 100px;
         
     | 
| 273 | 
         
            +
                display: flex;
         
     | 
| 274 | 
         
            +
                justify-content: center;
         
     | 
| 275 | 
         
            +
                align-items: center;
         
     | 
| 276 | 
         
            +
                background: var(--panel-background-fill);
         
     | 
| 277 | 
         
            +
                border-radius: var(--radius-lg);
         
     | 
| 278 | 
         
            +
                border: var(--block-border-width) solid var(--border-color-primary);
         
     | 
| 279 | 
         
            +
                padding: var(--block-padding);
         
     | 
| 280 | 
         
            +
            }
         
     | 
| 281 | 
         
            +
            #finetune_training_status .progress-block.is_training {
         
     | 
| 282 | 
         
            +
                min-height: 160px;
         
     | 
| 283 | 
         
            +
            }
         
     | 
| 284 | 
         
            +
            #finetune_training_status .progress-block .empty-text {
         
     | 
| 285 | 
         
            +
                text-transform: uppercase;
         
     | 
| 286 | 
         
            +
                font-weight: 700;
         
     | 
| 287 | 
         
            +
                font-size: 120%;
         
     | 
| 288 | 
         
            +
                opacity: 0.12;
         
     | 
| 289 | 
         
            +
            }
         
     | 
| 290 | 
         
            +
            #finetune_training_status .progress-block .meta-text {
         
     | 
| 291 | 
         
            +
                position: absolute;
         
     | 
| 292 | 
         
            +
                top: 0;
         
     | 
| 293 | 
         
            +
                right: 0;
         
     | 
| 294 | 
         
            +
                z-index: var(--layer-2);
         
     | 
| 295 | 
         
            +
                padding: var(--size-1) var(--size-2);
         
     | 
| 296 | 
         
            +
                font-size: var(--text-sm);
         
     | 
| 297 | 
         
            +
                font-family: var(--font-mono);
         
     | 
| 298 | 
         
            +
            }
         
     | 
| 299 | 
         
            +
            #finetune_training_status .progress-block .status {
         
     | 
| 300 | 
         
            +
                white-space: pre-wrap;
         
     | 
| 301 | 
         
            +
            }
         
     | 
| 302 | 
         
            +
            #finetune_training_status .progress-block .progress-level {
         
     | 
| 303 | 
         
            +
                display: flex;
         
     | 
| 304 | 
         
            +
                flex-direction: column;
         
     | 
| 305 | 
         
            +
                align-items: center;
         
     | 
| 306 | 
         
            +
                z-index: var(--layer-2);
         
     | 
| 307 | 
         
            +
                width: var(--size-full);
         
     | 
| 308 | 
         
            +
            }
         
     | 
| 309 | 
         
            +
            #finetune_training_status .progress-block .progress-level-inner {
         
     | 
| 310 | 
         
            +
                margin: var(--size-2) auto;
         
     | 
| 311 | 
         
            +
                color: var(--body-text-color);
         
     | 
| 312 | 
         
            +
                font-size: var(--text-sm);
         
     | 
| 313 | 
         
            +
                font-family: var(--font-mono);
         
     | 
| 314 | 
         
            +
            }
         
     | 
| 315 | 
         
            +
            #finetune_training_status .progress-block .progress-bar-wrap {
         
     | 
| 316 | 
         
            +
                border: 1px solid var(--border-color-primary);
         
     | 
| 317 | 
         
            +
                background: var(--background-fill-primary);
         
     | 
| 318 | 
         
            +
                width: 55.5%;
         
     | 
| 319 | 
         
            +
                height: var(--size-4);
         
     | 
| 320 | 
         
            +
            }
         
     | 
| 321 | 
         
            +
            #finetune_training_status .progress-block .progress-bar {
         
     | 
| 322 | 
         
            +
                transform-origin: left;
         
     | 
| 323 | 
         
            +
                background-color: var(--loader-color);
         
     | 
| 324 | 
         
            +
                width: var(--size-full);
         
     | 
| 325 | 
         
            +
                height: var(--size-full);
         
     | 
| 326 | 
         
            +
                transition: all 150ms ease 0s;
         
     | 
| 327 | 
         
             
            }
         
     | 
| 328 | 
         
            +
             
     | 
| 329 | 
         
            +
            #finetune_training_status .progress-block .output {
         
     | 
| 330 | 
         
            +
                display: flex;
         
     | 
| 331 | 
         
            +
                flex-direction: column;
         
     | 
| 332 | 
         
            +
                justify-content: center;
         
     | 
| 333 | 
         
            +
                align-items: center;
         
     | 
| 334 | 
         
            +
            }
         
     | 
| 335 | 
         
            +
            #finetune_training_status .progress-block .output .title {
         
     | 
| 336 | 
         
            +
                padding: var(--size-1) var(--size-3);
         
     | 
| 337 | 
         
            +
                font-weight: var(--weight-bold);
         
     | 
| 338 | 
         
            +
                font-size: var(--text-lg);
         
     | 
| 339 | 
         
            +
                line-height: var(--line-xs);
         
     | 
| 340 | 
         
            +
            }
         
     | 
| 341 | 
         
            +
            #finetune_training_status .progress-block .output .message {
         
     | 
| 342 | 
         
            +
                padding: var(--size-1) var(--size-3);
         
     | 
| 343 | 
         
            +
                color: var(--body-text-color) !important;
         
     | 
| 344 | 
         
            +
                font-family: var(--font-mono);
         
     | 
| 345 | 
         
            +
                white-space: pre-wrap;
         
     | 
| 346 | 
         
            +
            }
         
     | 
| 347 | 
         
            +
             
     | 
| 348 | 
         
            +
            #finetune_training_status .progress-block .error {
         
     | 
| 349 | 
         
            +
                display: flex;
         
     | 
| 350 | 
         
            +
                flex-direction: column;
         
     | 
| 351 | 
         
            +
                justify-content: center;
         
     | 
| 352 | 
         
            +
                align-items: center;
         
     | 
| 353 | 
         
            +
            }
         
     | 
| 354 | 
         
            +
            #finetune_training_status .progress-block .error .title {
         
     | 
| 355 | 
         
            +
                padding: var(--size-1) var(--size-3);
         
     | 
| 356 | 
         
            +
                color: var(--color-red-500);
         
     | 
| 357 | 
         
            +
                font-weight: var(--weight-bold);
         
     | 
| 358 | 
         
            +
                font-size: var(--text-lg);
         
     | 
| 359 | 
         
            +
                line-height: var(--line-xs);
         
     | 
| 360 | 
         
            +
            }
         
     | 
| 361 | 
         
            +
            #finetune_training_status .progress-block .error .error-message {
         
     | 
| 362 | 
         
            +
                padding: var(--size-1) var(--size-3);
         
     | 
| 363 | 
         
            +
                color: var(--body-text-color) !important;
         
     | 
| 364 | 
         
            +
                font-family: var(--font-mono);
         
     | 
| 365 | 
         
            +
                white-space: pre-wrap;
         
     | 
| 366 | 
         
            +
            }
         
     | 
| 367 | 
         
            +
            #finetune_training_status .progress-block.is_error {
         
     | 
| 368 | 
         
            +
                /* background: var(--error-background-fill) !important; */
         
     | 
| 369 | 
         
            +
                border: 1px solid var(--error-border-color) !important;
         
     | 
| 370 | 
         
            +
            }
         
     | 
| 371 | 
         
            +
             
     | 
| 372 | 
         
            +
            #finetune_training_indicator { display: none; }
         
     | 
    	
        llama_lora/ui/finetune/training.py
    CHANGED
    
    | 
         @@ -1,24 +1,26 @@ 
     | 
|
| 1 | 
         
             
            import os
         
     | 
| 2 | 
         
             
            import json
         
     | 
| 3 | 
         
             
            import time
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 4 | 
         
             
            import gradio as gr
         
     | 
| 5 | 
         
            -
            import math
         
     | 
| 6 | 
         | 
| 7 | 
         
            -
            from transformers import TrainerCallback
         
     | 
| 8 | 
         
             
            from huggingface_hub import try_to_load_from_cache, snapshot_download
         
     | 
| 9 | 
         | 
| 10 | 
         
             
            from ...config import Config
         
     | 
| 11 | 
         
             
            from ...globals import Global
         
     | 
| 12 | 
         
             
            from ...models import clear_cache, unload_models
         
     | 
| 13 | 
         
             
            from ...utils.prompter import Prompter
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 14 | 
         | 
| 15 | 
         
             
            from .data_processing import get_data_from_input
         
     | 
| 16 | 
         | 
| 17 | 
         
            -
            should_training_progress_track_tqdm = True
         
     | 
| 18 | 
         
            -
             
     | 
| 19 | 
         
            -
            if Global.gpu_total_cores is not None and Global.gpu_total_cores > 2560:
         
     | 
| 20 | 
         
            -
                should_training_progress_track_tqdm = False
         
     | 
| 21 | 
         
            -
             
     | 
| 22 | 
         | 
| 23 | 
         
             
            def do_train(
         
     | 
| 24 | 
         
             
                # Dataset
         
     | 
| 
         @@ -55,8 +57,14 @@ def do_train( 
     | 
|
| 55 | 
         
             
                model_name,
         
     | 
| 56 | 
         
             
                continue_from_model,
         
     | 
| 57 | 
         
             
                continue_from_checkpoint,
         
     | 
| 58 | 
         
            -
                progress=gr.Progress(track_tqdm= 
     | 
| 59 | 
         
             
            ):
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 60 | 
         
             
                try:
         
     | 
| 61 | 
         
             
                    base_model_name = Global.base_model_name
         
     | 
| 62 | 
         
             
                    tokenizer_name = Global.tokenizer_name or Global.base_model_name
         
     | 
| 
         @@ -115,18 +123,47 @@ def do_train( 
     | 
|
| 115 | 
         
             
                            raise ValueError(
         
     | 
| 116 | 
         
             
                                f"The output directory already exists and is not empty. ({output_dir})")
         
     | 
| 117 | 
         | 
| 118 | 
         
            -
                     
     | 
| 119 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 120 | 
         | 
| 121 | 
         
            -
                     
     | 
| 122 | 
         
            -
             
     | 
| 123 | 
         
            -
             
     | 
| 124 | 
         
            -
             
     | 
| 125 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 126 | 
         | 
| 127 | 
         
             
                    prompter = Prompter(template)
         
     | 
| 128 | 
         
            -
                    # variable_names = prompter.get_variable_names()
         
     | 
| 129 | 
         
            -
             
     | 
| 130 | 
         
             
                    data = get_data_from_input(
         
     | 
| 131 | 
         
             
                        load_dataset_from=load_dataset_from,
         
     | 
| 132 | 
         
             
                        dataset_text=dataset_text,
         
     | 
| 
         @@ -138,208 +175,234 @@ def do_train( 
     | 
|
| 138 | 
         
             
                        prompter=prompter
         
     | 
| 139 | 
         
             
                    )
         
     | 
| 140 | 
         | 
| 141 | 
         
            -
                     
     | 
| 142 | 
         
            -
             
     | 
| 143 | 
         
            -
             
     | 
| 144 | 
         
            -
                         
     | 
| 145 | 
         
            -
             
     | 
| 146 | 
         
            -
                             
     | 
| 147 | 
         
            -
             
     | 
| 148 | 
         
            -
             
     | 
| 149 | 
         
            -
             
     | 
| 150 | 
         
            -
             
     | 
| 151 | 
         
            -
             
     | 
| 152 | 
         
            -
             
     | 
| 153 | 
         
            -
             
     | 
| 154 | 
         
            -
             
     | 
| 155 | 
         
            -
             
     | 
| 156 | 
         
            -
             
     | 
| 157 | 
         
            -
             
     | 
| 158 | 
         
            -
             
     | 
| 159 | 
         
            -
             
     | 
| 160 | 
         
            -
             
     | 
| 161 | 
         
            -
             
     | 
| 162 | 
         
            -
             
     | 
| 163 | 
         
            -
             
     | 
| 164 | 
         
            -
             
     | 
| 165 | 
         
            -
             
     | 
| 166 | 
         
            -
             
     | 
| 167 | 
         
            -
             
     | 
| 168 | 
         
            -
             
     | 
| 169 | 
         
            -
             
     | 
| 170 | 
         
            -
             
     | 
| 171 | 
         
            -
             
     | 
| 172 | 
         
            -
             
     | 
| 173 | 
         
            -
             
     | 
| 174 | 
         
            -
             
     | 
| 175 | 
         
            -
             
     | 
| 176 | 
         
            -
             
     | 
| 177 | 
         
            -
             
     | 
| 178 | 
         
            -
             
     | 
| 179 | 
         
            -
             
     | 
| 180 | 
         
            -
             
     | 
| 181 | 
         
            -
             
     | 
| 182 | 
         
            -
             
     | 
| 183 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 184 | 
         
             
                                return
         
     | 
| 185 | 
         
            -
                            epochs = 3
         
     | 
| 186 | 
         
            -
                            epoch = i / 100
         
     | 
| 187 | 
         
            -
                            last_loss = None
         
     | 
| 188 | 
         
            -
                            if (i > 20):
         
     | 
| 189 | 
         
            -
                                last_loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
         
     | 
| 190 | 
         
            -
             
     | 
| 191 | 
         
            -
                            progress(
         
     | 
| 192 | 
         
            -
                                (i, 300),
         
     | 
| 193 | 
         
            -
                                desc="(Simulate) " +
         
     | 
| 194 | 
         
            -
                                get_progress_text(epoch, epochs, last_loss)
         
     | 
| 195 | 
         
            -
                            )
         
     | 
| 196 | 
         
            -
             
     | 
| 197 | 
         
            -
                            time.sleep(0.1)
         
     | 
| 198 | 
         
            -
             
     | 
| 199 | 
         
            -
                        time.sleep(2)
         
     | 
| 200 | 
         
            -
                        return message
         
     | 
| 201 | 
         
            -
             
     | 
| 202 | 
         
            -
                    if not should_training_progress_track_tqdm:
         
     | 
| 203 | 
         
            -
                        progress(
         
     | 
| 204 | 
         
            -
                            0, desc=f"Preparing model {base_model_name} for training...")
         
     | 
| 205 | 
         
            -
             
     | 
| 206 | 
         
            -
                    log_history = []
         
     | 
| 207 | 
         
            -
             
     | 
| 208 | 
         
            -
                    class UiTrainerCallback(TrainerCallback):
         
     | 
| 209 | 
         
            -
                        def _on_progress(self, args, state, control):
         
     | 
| 210 | 
         
            -
                            nonlocal log_history
         
     | 
| 211 | 
         | 
| 212 | 
         
            -
                             
     | 
| 213 | 
         
            -
             
     | 
| 214 | 
         
            -
                             
     | 
| 215 | 
         
            -
                                 
     | 
| 216 | 
         
            -
             
     | 
| 217 | 
         
            -
                             
     | 
| 218 | 
         
            -
             
     | 
| 219 | 
         
            -
             
     | 
| 220 | 
         
            -
             
     | 
| 221 | 
         
            -
             
     | 
| 222 | 
         
            -
             
     | 
| 223 | 
         
            -
             
     | 
| 224 | 
         
            -
             
     | 
| 225 | 
         
            -
             
     | 
| 226 | 
         
            -
             
     | 
| 227 | 
         
            -
             
     | 
| 228 | 
         
            -
             
     | 
| 229 | 
         
            -
                                 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 230 | 
         
             
                            )
         
     | 
| 231 | 
         | 
| 232 | 
         
            -
             
     | 
| 233 | 
         
            -
                             
     | 
| 234 | 
         
            -
             
     | 
| 235 | 
         
            -
                        def on_step_end(self, args, state, control, **kwargs):
         
     | 
| 236 | 
         
            -
                            self._on_progress(args, state, control)
         
     | 
| 237 | 
         
            -
             
     | 
| 238 | 
         
            -
                    training_callbacks = [UiTrainerCallback]
         
     | 
| 239 | 
         
            -
             
     | 
| 240 | 
         
            -
                    Global.should_stop_training = False
         
     | 
| 241 | 
         
            -
             
     | 
| 242 | 
         
            -
                    # Do not let other tqdm iterations interfere the progress reporting after training starts.
         
     | 
| 243 | 
         
            -
                    # progress.track_tqdm = False  # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
         
     | 
| 244 | 
         | 
| 245 | 
         
            -
             
     | 
| 246 | 
         
            -
                        os.makedirs(output_dir)
         
     | 
| 247 | 
         | 
| 248 | 
         
            -
             
     | 
| 249 | 
         
            -
                        dataset_name = "N/A (from text input)"
         
     | 
| 250 | 
         
            -
                        if load_dataset_from == "Data Dir":
         
     | 
| 251 | 
         
            -
                            dataset_name = dataset_from_data_dir
         
     | 
| 252 | 
         | 
| 253 | 
         
            -
                         
     | 
| 254 | 
         
            -
                             
     | 
| 255 | 
         
            -
                             
     | 
| 256 | 
         
            -
             
     | 
| 257 | 
         
            -
                             
     | 
| 258 | 
         
            -
                            'timestamp': time.time(),
         
     | 
| 259 | 
         | 
| 260 | 
         
            -
             
     | 
| 261 | 
         
            -
             
     | 
| 262 | 
         
            -
             
     | 
| 263 | 
         
            -
             
     | 
| 264 | 
         
            -
                            # 'micro_batch_size': micro_batch_size,
         
     | 
| 265 | 
         
            -
                            # 'gradient_accumulation_steps': gradient_accumulation_steps,
         
     | 
| 266 | 
         
            -
                            # 'epochs': epochs,
         
     | 
| 267 | 
         
            -
                            # 'learning_rate': learning_rate,
         
     | 
| 268 | 
         
            -
             
     | 
| 269 | 
         
            -
                            # 'evaluate_data_count': evaluate_data_count,
         
     | 
| 270 | 
         
            -
             
     | 
| 271 | 
         
            -
                            # 'lora_r': lora_r,
         
     | 
| 272 | 
         
            -
                            # 'lora_alpha': lora_alpha,
         
     | 
| 273 | 
         
            -
                            # 'lora_dropout': lora_dropout,
         
     | 
| 274 | 
         
            -
                            # 'lora_target_modules': lora_target_modules,
         
     | 
| 275 | 
         
            -
                        }
         
     | 
| 276 | 
         
            -
                        if continue_from_model:
         
     | 
| 277 | 
         
            -
                            info['continued_from_model'] = continue_from_model
         
     | 
| 278 | 
         
            -
                            if continue_from_checkpoint:
         
     | 
| 279 | 
         
            -
                                info['continued_from_checkpoint'] = continue_from_checkpoint
         
     | 
| 280 | 
         
            -
             
     | 
| 281 | 
         
            -
                        if Global.version:
         
     | 
| 282 | 
         
            -
                            info['tuner_version'] = Global.version
         
     | 
| 283 | 
         
            -
             
     | 
| 284 | 
         
            -
                        json.dump(info, info_json_file, indent=2)
         
     | 
| 285 | 
         
            -
             
     | 
| 286 | 
         
            -
                    if not should_training_progress_track_tqdm:
         
     | 
| 287 | 
         
            -
                        progress(0, desc="Train starting...")
         
     | 
| 288 | 
         
            -
             
     | 
| 289 | 
         
            -
                    wandb_group = template
         
     | 
| 290 | 
         
            -
                    wandb_tags = [f"template:{template}"]
         
     | 
| 291 | 
         
            -
                    if load_dataset_from == "Data Dir" and dataset_from_data_dir:
         
     | 
| 292 | 
         
            -
                        wandb_group += f"/{dataset_from_data_dir}"
         
     | 
| 293 | 
         
            -
                        wandb_tags.append(f"dataset:{dataset_from_data_dir}")
         
     | 
| 294 | 
         
            -
             
     | 
| 295 | 
         
            -
                    train_output = Global.finetune_train_fn(
         
     | 
| 296 | 
         
            -
                        base_model=base_model_name,
         
     | 
| 297 | 
         
            -
                        tokenizer=tokenizer_name,
         
     | 
| 298 | 
         
            -
                        output_dir=output_dir,
         
     | 
| 299 | 
         
            -
                        train_data=train_data,
         
     | 
| 300 | 
         
            -
                        # 128,  # batch_size (is not used, use gradient_accumulation_steps instead)
         
     | 
| 301 | 
         
            -
                        micro_batch_size=micro_batch_size,
         
     | 
| 302 | 
         
            -
                        gradient_accumulation_steps=gradient_accumulation_steps,
         
     | 
| 303 | 
         
            -
                        num_train_epochs=epochs,
         
     | 
| 304 | 
         
            -
                        learning_rate=learning_rate,
         
     | 
| 305 | 
         
            -
                        cutoff_len=max_seq_length,
         
     | 
| 306 | 
         
            -
                        val_set_size=evaluate_data_count,
         
     | 
| 307 | 
         
            -
                        lora_r=lora_r,
         
     | 
| 308 | 
         
            -
                        lora_alpha=lora_alpha,
         
     | 
| 309 | 
         
            -
                        lora_dropout=lora_dropout,
         
     | 
| 310 | 
         
            -
                        lora_target_modules=lora_target_modules,
         
     | 
| 311 | 
         
            -
                        lora_modules_to_save=lora_modules_to_save,
         
     | 
| 312 | 
         
            -
                        train_on_inputs=train_on_inputs,
         
     | 
| 313 | 
         
            -
                        load_in_8bit=load_in_8bit,
         
     | 
| 314 | 
         
            -
                        fp16=fp16,
         
     | 
| 315 | 
         
            -
                        bf16=bf16,
         
     | 
| 316 | 
         
            -
                        gradient_checkpointing=gradient_checkpointing,
         
     | 
| 317 | 
         
            -
                        group_by_length=False,
         
     | 
| 318 | 
         
            -
                        resume_from_checkpoint=resume_from_checkpoint_param,
         
     | 
| 319 | 
         
            -
                        save_steps=save_steps,
         
     | 
| 320 | 
         
            -
                        save_total_limit=save_total_limit,
         
     | 
| 321 | 
         
            -
                        logging_steps=logging_steps,
         
     | 
| 322 | 
         
            -
                        additional_training_arguments=additional_training_arguments,
         
     | 
| 323 | 
         
            -
                        additional_lora_config=additional_lora_config,
         
     | 
| 324 | 
         
            -
                        callbacks=training_callbacks,
         
     | 
| 325 | 
         
            -
                        wandb_api_key=Config.wandb_api_key,
         
     | 
| 326 | 
         
            -
                        wandb_project=Config.default_wandb_project if Config.enable_wandb else None,
         
     | 
| 327 | 
         
            -
                        wandb_group=wandb_group,
         
     | 
| 328 | 
         
            -
                        wandb_run_name=model_name,
         
     | 
| 329 | 
         
            -
                        wandb_tags=wandb_tags
         
     | 
| 330 | 
         
            -
                    )
         
     | 
| 331 | 
         
            -
             
     | 
| 332 | 
         
            -
                    logs_str = "\n".join([json.dumps(log)
         
     | 
| 333 | 
         
            -
                                         for log in log_history]) or "None"
         
     | 
| 334 | 
         
            -
             
     | 
| 335 | 
         
            -
                    result_message = f"Training ended:\n{str(train_output)}"
         
     | 
| 336 | 
         
            -
                    print(result_message)
         
     | 
| 337 | 
         
            -
                    # result_message += f"\n\nLogs:\n{logs_str}"
         
     | 
| 338 | 
         
            -
             
     | 
| 339 | 
         
            -
                    clear_cache()
         
     | 
| 340 | 
         
            -
             
     | 
| 341 | 
         
            -
                    return result_message
         
     | 
| 342 | 
         | 
| 343 | 
         
             
                except Exception as e:
         
     | 
| 344 | 
         
            -
                     
     | 
| 345 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
             
            import os
         
     | 
| 2 | 
         
             
            import json
         
     | 
| 3 | 
         
             
            import time
         
     | 
| 4 | 
         
            +
            import datetime
         
     | 
| 5 | 
         
            +
            import pytz
         
     | 
| 6 | 
         
            +
            import socket
         
     | 
| 7 | 
         
            +
            import threading
         
     | 
| 8 | 
         
            +
            import traceback
         
     | 
| 9 | 
         
             
            import gradio as gr
         
     | 
| 
         | 
|
| 10 | 
         | 
| 
         | 
|
| 11 | 
         
             
            from huggingface_hub import try_to_load_from_cache, snapshot_download
         
     | 
| 12 | 
         | 
| 13 | 
         
             
            from ...config import Config
         
     | 
| 14 | 
         
             
            from ...globals import Global
         
     | 
| 15 | 
         
             
            from ...models import clear_cache, unload_models
         
     | 
| 16 | 
         
             
            from ...utils.prompter import Prompter
         
     | 
| 17 | 
         
            +
            from ..trainer_callback import (
         
     | 
| 18 | 
         
            +
                UiTrainerCallback, reset_training_status,
         
     | 
| 19 | 
         
            +
                update_training_states, set_train_output
         
     | 
| 20 | 
         
            +
            )
         
     | 
| 21 | 
         | 
| 22 | 
         
             
            from .data_processing import get_data_from_input
         
     | 
| 23 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 24 | 
         | 
| 25 | 
         
             
            def do_train(
         
     | 
| 26 | 
         
             
                # Dataset
         
     | 
| 
         | 
|
| 57 | 
         
             
                model_name,
         
     | 
| 58 | 
         
             
                continue_from_model,
         
     | 
| 59 | 
         
             
                continue_from_checkpoint,
         
     | 
| 60 | 
         
            +
                progress=gr.Progress(track_tqdm=False),
         
     | 
| 61 | 
         
             
            ):
         
     | 
| 62 | 
         
            +
                if Global.is_training:
         
     | 
| 63 | 
         
            +
                    return render_training_status()
         
     | 
| 64 | 
         
            +
             
     | 
| 65 | 
         
            +
                reset_training_status()
         
     | 
| 66 | 
         
            +
                Global.is_train_starting = True
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
             
                try:
         
     | 
| 69 | 
         
             
                    base_model_name = Global.base_model_name
         
     | 
| 70 | 
         
             
                    tokenizer_name = Global.tokenizer_name or Global.base_model_name
         
     | 
| 
         | 
|
| 123 | 
         
             
                            raise ValueError(
         
     | 
| 124 | 
         
             
                                f"The output directory already exists and is not empty. ({output_dir})")
         
     | 
| 125 | 
         | 
| 126 | 
         
            +
                    wandb_group = template
         
     | 
| 127 | 
         
            +
                    wandb_tags = [f"template:{template}"]
         
     | 
| 128 | 
         
            +
                    if load_dataset_from == "Data Dir" and dataset_from_data_dir:
         
     | 
| 129 | 
         
            +
                        wandb_group += f"/{dataset_from_data_dir}"
         
     | 
| 130 | 
         
            +
                        wandb_tags.append(f"dataset:{dataset_from_data_dir}")
         
     | 
| 131 | 
         | 
| 132 | 
         
            +
                    finetune_args = {
         
     | 
| 133 | 
         
            +
                        'base_model': base_model_name,
         
     | 
| 134 | 
         
            +
                        'tokenizer': tokenizer_name,
         
     | 
| 135 | 
         
            +
                        'output_dir': output_dir,
         
     | 
| 136 | 
         
            +
                        'micro_batch_size': micro_batch_size,
         
     | 
| 137 | 
         
            +
                        'gradient_accumulation_steps': gradient_accumulation_steps,
         
     | 
| 138 | 
         
            +
                        'num_train_epochs': epochs,
         
     | 
| 139 | 
         
            +
                        'learning_rate': learning_rate,
         
     | 
| 140 | 
         
            +
                        'cutoff_len': max_seq_length,
         
     | 
| 141 | 
         
            +
                        'val_set_size': evaluate_data_count,
         
     | 
| 142 | 
         
            +
                        'lora_r': lora_r,
         
     | 
| 143 | 
         
            +
                        'lora_alpha': lora_alpha,
         
     | 
| 144 | 
         
            +
                        'lora_dropout': lora_dropout,
         
     | 
| 145 | 
         
            +
                        'lora_target_modules': lora_target_modules,
         
     | 
| 146 | 
         
            +
                        'lora_modules_to_save': lora_modules_to_save,
         
     | 
| 147 | 
         
            +
                        'train_on_inputs': train_on_inputs,
         
     | 
| 148 | 
         
            +
                        'load_in_8bit': load_in_8bit,
         
     | 
| 149 | 
         
            +
                        'fp16': fp16,
         
     | 
| 150 | 
         
            +
                        'bf16': bf16,
         
     | 
| 151 | 
         
            +
                        'gradient_checkpointing': gradient_checkpointing,
         
     | 
| 152 | 
         
            +
                        'group_by_length': False,
         
     | 
| 153 | 
         
            +
                        'resume_from_checkpoint': resume_from_checkpoint_param,
         
     | 
| 154 | 
         
            +
                        'save_steps': save_steps,
         
     | 
| 155 | 
         
            +
                        'save_total_limit': save_total_limit,
         
     | 
| 156 | 
         
            +
                        'logging_steps': logging_steps,
         
     | 
| 157 | 
         
            +
                        'additional_training_arguments': additional_training_arguments,
         
     | 
| 158 | 
         
            +
                        'additional_lora_config': additional_lora_config,
         
     | 
| 159 | 
         
            +
                        'wandb_api_key': Config.wandb_api_key,
         
     | 
| 160 | 
         
            +
                        'wandb_project': Config.default_wandb_project if Config.enable_wandb else None,
         
     | 
| 161 | 
         
            +
                        'wandb_group': wandb_group,
         
     | 
| 162 | 
         
            +
                        'wandb_run_name': model_name,
         
     | 
| 163 | 
         
            +
                        'wandb_tags': wandb_tags
         
     | 
| 164 | 
         
            +
                    }
         
     | 
| 165 | 
         | 
| 166 | 
         
             
                    prompter = Prompter(template)
         
     | 
| 
         | 
|
| 
         | 
|
| 167 | 
         
             
                    data = get_data_from_input(
         
     | 
| 168 | 
         
             
                        load_dataset_from=load_dataset_from,
         
     | 
| 169 | 
         
             
                        dataset_text=dataset_text,
         
     | 
| 
         | 
|
| 175 | 
         
             
                        prompter=prompter
         
     | 
| 176 | 
         
             
                    )
         
     | 
| 177 | 
         | 
| 178 | 
         
            +
                    def training():
         
     | 
| 179 | 
         
            +
                        Global.is_training = True
         
     | 
| 180 | 
         
            +
             
     | 
| 181 | 
         
            +
                        try:
         
     | 
| 182 | 
         
            +
                            # Need RAM for training
         
     | 
| 183 | 
         
            +
                            unload_models()
         
     | 
| 184 | 
         
            +
                            Global.new_base_model_that_is_ready_to_be_used = None
         
     | 
| 185 | 
         
            +
                            Global.name_of_new_base_model_that_is_ready_to_be_used = None
         
     | 
| 186 | 
         
            +
                            clear_cache()
         
     | 
| 187 | 
         
            +
             
     | 
| 188 | 
         
            +
                            train_data = prompter.get_train_data_from_dataset(data)
         
     | 
| 189 | 
         
            +
             
     | 
| 190 | 
         
            +
                            if Config.ui_dev_mode:
         
     | 
| 191 | 
         
            +
                                message = "Currently in UI dev mode, not doing the actual training."
         
     | 
| 192 | 
         
            +
                                message += f"\n\nArgs: {json.dumps(finetune_args, indent=2)}"
         
     | 
| 193 | 
         
            +
                                message += f"\n\nTrain data (first 5):\n{json.dumps(train_data[:5], indent=2)}"
         
     | 
| 194 | 
         
            +
             
     | 
| 195 | 
         
            +
                                print(message)
         
     | 
| 196 | 
         
            +
             
     | 
| 197 | 
         
            +
                                total_steps = 300
         
     | 
| 198 | 
         
            +
                                for i in range(300):
         
     | 
| 199 | 
         
            +
                                    if (Global.should_stop_training):
         
     | 
| 200 | 
         
            +
                                        break
         
     | 
| 201 | 
         
            +
             
     | 
| 202 | 
         
            +
                                    current_step = i + 1
         
     | 
| 203 | 
         
            +
                                    total_epochs = 3
         
     | 
| 204 | 
         
            +
                                    current_epoch = i / 100
         
     | 
| 205 | 
         
            +
                                    log_history = []
         
     | 
| 206 | 
         
            +
             
     | 
| 207 | 
         
            +
                                    if (i > 20):
         
     | 
| 208 | 
         
            +
                                        loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
         
     | 
| 209 | 
         
            +
                                        log_history = [{'loss': loss}]
         
     | 
| 210 | 
         
            +
             
     | 
| 211 | 
         
            +
                                    update_training_states(
         
     | 
| 212 | 
         
            +
                                        total_steps=total_steps,
         
     | 
| 213 | 
         
            +
                                        current_step=current_step,
         
     | 
| 214 | 
         
            +
                                        total_epochs=total_epochs,
         
     | 
| 215 | 
         
            +
                                        current_epoch=current_epoch,
         
     | 
| 216 | 
         
            +
                                        log_history=log_history
         
     | 
| 217 | 
         
            +
                                    )
         
     | 
| 218 | 
         
            +
                                    time.sleep(0.1)
         
     | 
| 219 | 
         
            +
             
     | 
| 220 | 
         
            +
                                result_message = set_train_output(message)
         
     | 
| 221 | 
         
            +
                                print(result_message)
         
     | 
| 222 | 
         
            +
                                time.sleep(1)
         
     | 
| 223 | 
         
            +
                                Global.is_training = False
         
     | 
| 224 | 
         
             
                                return
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 225 | 
         | 
| 226 | 
         
            +
                            training_callbacks = [UiTrainerCallback]
         
     | 
| 227 | 
         
            +
             
     | 
| 228 | 
         
            +
                            if not os.path.exists(output_dir):
         
     | 
| 229 | 
         
            +
                                os.makedirs(output_dir)
         
     | 
| 230 | 
         
            +
             
     | 
| 231 | 
         
            +
                            with open(os.path.join(output_dir, "info.json"), 'w') as info_json_file:
         
     | 
| 232 | 
         
            +
                                dataset_name = "N/A (from text input)"
         
     | 
| 233 | 
         
            +
                                if load_dataset_from == "Data Dir":
         
     | 
| 234 | 
         
            +
                                    dataset_name = dataset_from_data_dir
         
     | 
| 235 | 
         
            +
             
     | 
| 236 | 
         
            +
                                info = {
         
     | 
| 237 | 
         
            +
                                    'base_model': base_model_name,
         
     | 
| 238 | 
         
            +
                                    'prompt_template': template,
         
     | 
| 239 | 
         
            +
                                    'dataset_name': dataset_name,
         
     | 
| 240 | 
         
            +
                                    'dataset_rows': len(train_data),
         
     | 
| 241 | 
         
            +
                                    'trained_on_machine': socket.gethostname(),
         
     | 
| 242 | 
         
            +
                                    'timestamp': time.time(),
         
     | 
| 243 | 
         
            +
                                }
         
     | 
| 244 | 
         
            +
                                if continue_from_model:
         
     | 
| 245 | 
         
            +
                                    info['continued_from_model'] = continue_from_model
         
     | 
| 246 | 
         
            +
                                    if continue_from_checkpoint:
         
     | 
| 247 | 
         
            +
                                        info['continued_from_checkpoint'] = continue_from_checkpoint
         
     | 
| 248 | 
         
            +
             
     | 
| 249 | 
         
            +
                                if Global.version:
         
     | 
| 250 | 
         
            +
                                    info['tuner_version'] = Global.version
         
     | 
| 251 | 
         
            +
             
     | 
| 252 | 
         
            +
                                json.dump(info, info_json_file, indent=2)
         
     | 
| 253 | 
         
            +
             
     | 
| 254 | 
         
            +
                            train_output = Global.finetune_train_fn(
         
     | 
| 255 | 
         
            +
                                train_data=train_data,
         
     | 
| 256 | 
         
            +
                                callbacks=training_callbacks,
         
     | 
| 257 | 
         
            +
                                **finetune_args,
         
     | 
| 258 | 
         
             
                            )
         
     | 
| 259 | 
         | 
| 260 | 
         
            +
                            result_message = set_train_output(train_output)
         
     | 
| 261 | 
         
            +
                            print(result_message + "\n" + str(train_output))
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 262 | 
         | 
| 263 | 
         
            +
                            clear_cache()
         
     | 
| 
         | 
|
| 264 | 
         | 
| 265 | 
         
            +
                            Global.is_training = False
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 266 | 
         | 
| 267 | 
         
            +
                        except Exception as e:
         
     | 
| 268 | 
         
            +
                            traceback.print_exc()
         
     | 
| 269 | 
         
            +
                            Global.training_error_message = str(e)
         
     | 
| 270 | 
         
            +
                        finally:
         
     | 
| 271 | 
         
            +
                            Global.is_training = False
         
     | 
| 
         | 
|
| 272 | 
         | 
| 273 | 
         
            +
                    training_thread = threading.Thread(target=training)
         
     | 
| 274 | 
         
            +
                    training_thread.daemon = True
         
     | 
| 275 | 
         
            +
                    training_thread.start()
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 276 | 
         | 
| 277 | 
         
             
                except Exception as e:
         
     | 
| 278 | 
         
            +
                    Global.is_training = False
         
     | 
| 279 | 
         
            +
                    traceback.print_exc()
         
     | 
| 280 | 
         
            +
                    Global.training_error_message = str(e)
         
     | 
| 281 | 
         
            +
                finally:
         
     | 
| 282 | 
         
            +
                    Global.is_train_starting = False
         
     | 
| 283 | 
         
            +
             
     | 
| 284 | 
         
            +
                return render_training_status()
         
     | 
| 285 | 
         
            +
             
     | 
| 286 | 
         
            +
             
     | 
| 287 | 
         
            +
            def render_training_status():
         
     | 
| 288 | 
         
            +
                if not Global.is_training:
         
     | 
| 289 | 
         
            +
                    if Global.is_train_starting:
         
     | 
| 290 | 
         
            +
                        html_content = """
         
     | 
| 291 | 
         
            +
                        <div class="progress-block">
         
     | 
| 292 | 
         
            +
                          <div class="progress-level">
         
     | 
| 293 | 
         
            +
                            <div class="progress-level-inner">
         
     | 
| 294 | 
         
            +
                              Starting...
         
     | 
| 295 | 
         
            +
                            </div>
         
     | 
| 296 | 
         
            +
                          </div>
         
     | 
| 297 | 
         
            +
                        </div>
         
     | 
| 298 | 
         
            +
                        """
         
     | 
| 299 | 
         
            +
                        return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
         
     | 
| 300 | 
         
            +
             
     | 
| 301 | 
         
            +
                    if Global.training_error_message:
         
     | 
| 302 | 
         
            +
                        html_content = f"""
         
     | 
| 303 | 
         
            +
                        <div class="progress-block is_error">
         
     | 
| 304 | 
         
            +
                          <div class="progress-level">
         
     | 
| 305 | 
         
            +
                            <div class="error">
         
     | 
| 306 | 
         
            +
                              <div class="title">
         
     | 
| 307 | 
         
            +
                                ⚠ Something went wrong
         
     | 
| 308 | 
         
            +
                              </div>
         
     | 
| 309 | 
         
            +
                              <div class="error-message">{Global.training_error_message}</div>
         
     | 
| 310 | 
         
            +
                            </div>
         
     | 
| 311 | 
         
            +
                          </div>
         
     | 
| 312 | 
         
            +
                        </div>
         
     | 
| 313 | 
         
            +
                        """
         
     | 
| 314 | 
         
            +
                        return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
         
     | 
| 315 | 
         
            +
             
     | 
| 316 | 
         
            +
                    if Global.train_output_str:
         
     | 
| 317 | 
         
            +
                        end_message = "✅ Training completed"
         
     | 
| 318 | 
         
            +
                        if Global.should_stop_training:
         
     | 
| 319 | 
         
            +
                            end_message = "🛑 Train aborted"
         
     | 
| 320 | 
         
            +
                        html_content = f"""
         
     | 
| 321 | 
         
            +
                        <div class="progress-block">
         
     | 
| 322 | 
         
            +
                          <div class="progress-level">
         
     | 
| 323 | 
         
            +
                            <div class="output">
         
     | 
| 324 | 
         
            +
                              <div class="title">
         
     | 
| 325 | 
         
            +
                                {end_message}
         
     | 
| 326 | 
         
            +
                              </div>
         
     | 
| 327 | 
         
            +
                              <div class="message">{Global.train_output_str}</div>
         
     | 
| 328 | 
         
            +
                            </div>
         
     | 
| 329 | 
         
            +
                          </div>
         
     | 
| 330 | 
         
            +
                        </div>
         
     | 
| 331 | 
         
            +
                        """
         
     | 
| 332 | 
         
            +
                        return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
         
     | 
| 333 | 
         
            +
             
     | 
| 334 | 
         
            +
                    if Global.training_status_text:
         
     | 
| 335 | 
         
            +
                        html_content = f"""
         
     | 
| 336 | 
         
            +
                        <div class="progress-block">
         
     | 
| 337 | 
         
            +
                          <div class="status">{Global.training_status_text}</div>
         
     | 
| 338 | 
         
            +
                        </div>
         
     | 
| 339 | 
         
            +
                        """
         
     | 
| 340 | 
         
            +
                        return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
         
     | 
| 341 | 
         
            +
             
     | 
| 342 | 
         
            +
                    html_content = """
         
     | 
| 343 | 
         
            +
                    <div class="progress-block">
         
     | 
| 344 | 
         
            +
                      <div class="empty-text">
         
     | 
| 345 | 
         
            +
                        Training status will be shown here
         
     | 
| 346 | 
         
            +
                      </div>
         
     | 
| 347 | 
         
            +
                    </div>
         
     | 
| 348 | 
         
            +
                    """
         
     | 
| 349 | 
         
            +
                    return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
         
     | 
| 350 | 
         
            +
             
     | 
| 351 | 
         
            +
                meta_info = []
         
     | 
| 352 | 
         
            +
                meta_info.append(
         
     | 
| 353 | 
         
            +
                    f"{Global.training_current_step}/{Global.training_total_steps} steps")
         
     | 
| 354 | 
         
            +
                current_time = time.time()
         
     | 
| 355 | 
         
            +
                time_elapsed = current_time - Global.train_started_at
         
     | 
| 356 | 
         
            +
                time_remaining = -1
         
     | 
| 357 | 
         
            +
                if Global.training_eta:
         
     | 
| 358 | 
         
            +
                    time_remaining = Global.training_eta - current_time
         
     | 
| 359 | 
         
            +
                if time_remaining >= 0:
         
     | 
| 360 | 
         
            +
                    meta_info.append(
         
     | 
| 361 | 
         
            +
                        f"{format_time(time_elapsed)}<{format_time(time_remaining)}")
         
     | 
| 362 | 
         
            +
                    meta_info.append(f"ETA: {format_timestamp(Global.training_eta)}")
         
     | 
| 363 | 
         
            +
                else:
         
     | 
| 364 | 
         
            +
                    meta_info.append(format_time(time_elapsed))
         
     | 
| 365 | 
         
            +
             
     | 
| 366 | 
         
            +
                html_content = f"""
         
     | 
| 367 | 
         
            +
                <div class="progress-block is_training">
         
     | 
| 368 | 
         
            +
                  <div class="meta-text">{' | '.join(meta_info)}</div>
         
     | 
| 369 | 
         
            +
                  <div class="progress-level">
         
     | 
| 370 | 
         
            +
                    <div class="progress-level-inner">
         
     | 
| 371 | 
         
            +
                      {Global.training_status_text} - {Global.training_progress * 100:.2f}%
         
     | 
| 372 | 
         
            +
                    </div>
         
     | 
| 373 | 
         
            +
                    <div class="progress-bar-wrap">
         
     | 
| 374 | 
         
            +
                      <div class="progress-bar" style="width: {Global.training_progress * 100:.2f}%;">
         
     | 
| 375 | 
         
            +
                      </div>
         
     | 
| 376 | 
         
            +
                    </div>
         
     | 
| 377 | 
         
            +
                  </div>
         
     | 
| 378 | 
         
            +
                </div>
         
     | 
| 379 | 
         
            +
                """
         
     | 
| 380 | 
         
            +
                return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
         
     | 
| 381 | 
         
            +
             
     | 
| 382 | 
         
            +
             
     | 
| 383 | 
         
            +
            def format_time(seconds):
         
     | 
| 384 | 
         
            +
                hours, remainder = divmod(seconds, 3600)
         
     | 
| 385 | 
         
            +
                minutes, seconds = divmod(remainder, 60)
         
     | 
| 386 | 
         
            +
                if hours == 0:
         
     | 
| 387 | 
         
            +
                    return "{:02d}:{:02d}".format(int(minutes), int(seconds))
         
     | 
| 388 | 
         
            +
                else:
         
     | 
| 389 | 
         
            +
                    return "{:02d}:{:02d}:{:02d}".format(int(hours), int(minutes), int(seconds))
         
     | 
| 390 | 
         
            +
             
     | 
| 391 | 
         
            +
             
     | 
| 392 | 
         
            +
            def format_timestamp(timestamp):
         
     | 
| 393 | 
         
            +
                dt_naive = datetime.datetime.utcfromtimestamp(timestamp)
         
     | 
| 394 | 
         
            +
                utc = pytz.UTC
         
     | 
| 395 | 
         
            +
                timezone = Config.timezone
         
     | 
| 396 | 
         
            +
                dt_aware = utc.localize(dt_naive).astimezone(timezone)
         
     | 
| 397 | 
         
            +
                now = datetime.datetime.now(timezone)
         
     | 
| 398 | 
         
            +
                delta = dt_aware.date() - now.date()
         
     | 
| 399 | 
         
            +
                if delta.days == 0:
         
     | 
| 400 | 
         
            +
                    time_str = ""
         
     | 
| 401 | 
         
            +
                elif delta.days == 1:
         
     | 
| 402 | 
         
            +
                    time_str = "tomorrow at "
         
     | 
| 403 | 
         
            +
                elif delta.days == -1:
         
     | 
| 404 | 
         
            +
                    time_str = "yesterday at "
         
     | 
| 405 | 
         
            +
                else:
         
     | 
| 406 | 
         
            +
                    time_str = dt_aware.strftime('%A, %B %d at ')
         
     | 
| 407 | 
         
            +
                time_str += dt_aware.strftime('%I:%M %p').lower()
         
     | 
| 408 | 
         
            +
                return time_str
         
     | 
    	
        llama_lora/ui/inference_ui.py
    CHANGED
    
    | 
         @@ -381,7 +381,7 @@ def inference_ui(): 
     | 
|
| 381 | 
         
             
                things_that_might_timeout = []
         
     | 
| 382 | 
         | 
| 383 | 
         
             
                with gr.Blocks() as inference_ui_blocks:
         
     | 
| 384 | 
         
            -
                    with gr.Row():
         
     | 
| 385 | 
         
             
                        with gr.Column(elem_id="inference_lora_model_group"):
         
     | 
| 386 | 
         
             
                            model_prompt_template_message = gr.Markdown(
         
     | 
| 387 | 
         
             
                                "", visible=False, elem_id="inference_lora_model_prompt_template_message")
         
     | 
| 
         @@ -402,7 +402,7 @@ def inference_ui(): 
     | 
|
| 402 | 
         
             
                        reload_selections_button.style(
         
     | 
| 403 | 
         
             
                            full_width=False,
         
     | 
| 404 | 
         
             
                            size="sm")
         
     | 
| 405 | 
         
            -
                    with gr.Row():
         
     | 
| 406 | 
         
             
                        with gr.Column():
         
     | 
| 407 | 
         
             
                            with gr.Column(elem_id="inference_prompt_box"):
         
     | 
| 408 | 
         
             
                                variable_0 = gr.Textbox(
         
     | 
| 
         | 
|
| 381 | 
         
             
                things_that_might_timeout = []
         
     | 
| 382 | 
         | 
| 383 | 
         
             
                with gr.Blocks() as inference_ui_blocks:
         
     | 
| 384 | 
         
            +
                    with gr.Row(elem_classes="disable_while_training"):
         
     | 
| 385 | 
         
             
                        with gr.Column(elem_id="inference_lora_model_group"):
         
     | 
| 386 | 
         
             
                            model_prompt_template_message = gr.Markdown(
         
     | 
| 387 | 
         
             
                                "", visible=False, elem_id="inference_lora_model_prompt_template_message")
         
     | 
| 
         | 
|
| 402 | 
         
             
                        reload_selections_button.style(
         
     | 
| 403 | 
         
             
                            full_width=False,
         
     | 
| 404 | 
         
             
                            size="sm")
         
     | 
| 405 | 
         
            +
                    with gr.Row(elem_classes="disable_while_training"):
         
     | 
| 406 | 
         
             
                        with gr.Column():
         
     | 
| 407 | 
         
             
                            with gr.Column(elem_id="inference_prompt_box"):
         
     | 
| 408 | 
         
             
                                variable_0 = gr.Textbox(
         
     | 
    	
        llama_lora/ui/main_page.py
    CHANGED
    
    | 
         @@ -18,6 +18,8 @@ def main_page(): 
     | 
|
| 18 | 
         
             
                        title=title,
         
     | 
| 19 | 
         
             
                        css=get_css_styles(),
         
     | 
| 20 | 
         
             
                ) as main_page_blocks:
         
     | 
| 
         | 
|
| 
         | 
|
| 21 | 
         
             
                    with gr.Column(elem_id="main_page_content"):
         
     | 
| 22 | 
         
             
                        with gr.Row():
         
     | 
| 23 | 
         
             
                            gr.Markdown(
         
     | 
| 
         @@ -27,7 +29,10 @@ def main_page(): 
     | 
|
| 27 | 
         
             
                                """,
         
     | 
| 28 | 
         
             
                                elem_id="page_title",
         
     | 
| 29 | 
         
             
                            )
         
     | 
| 30 | 
         
            -
                            with gr.Column( 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 31 | 
         
             
                                global_base_model_select = gr.Dropdown(
         
     | 
| 32 | 
         
             
                                    label="Base Model",
         
     | 
| 33 | 
         
             
                                    elem_id="global_base_model_select",
         
     | 
| 
         @@ -99,6 +104,19 @@ def main_page(): 
     | 
|
| 99 | 
         
             
                    ]
         
     | 
| 100 | 
         
             
                )
         
     | 
| 101 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 102 | 
         
             
                main_page_blocks.load(_js=f"""
         
     | 
| 103 | 
         
             
                function () {{
         
     | 
| 104 | 
         
             
                    {popperjs_core_code()}
         
     | 
| 
         @@ -239,6 +257,12 @@ def main_page_custom_css(): 
     | 
|
| 239 | 
         
             
                }
         
     | 
| 240 | 
         
             
                */
         
     | 
| 241 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 242 | 
         
             
                .error-message, .error-message p {
         
     | 
| 243 | 
         
             
                    color: var(--error-text-color) !important;
         
     | 
| 244 | 
         
             
                }
         
     | 
| 
         @@ -261,6 +285,36 @@ def main_page_custom_css(): 
     | 
|
| 261 | 
         
             
                    max-height: unset;
         
     | 
| 262 | 
         
             
                }
         
     | 
| 263 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 264 | 
         
             
                #page_title {
         
     | 
| 265 | 
         
             
                    flex-grow: 3;
         
     | 
| 266 | 
         
             
                }
         
     | 
| 
         | 
|
| 18 | 
         
             
                        title=title,
         
     | 
| 19 | 
         
             
                        css=get_css_styles(),
         
     | 
| 20 | 
         
             
                ) as main_page_blocks:
         
     | 
| 21 | 
         
            +
                    training_indicator = gr.HTML(
         
     | 
| 22 | 
         
            +
                        "", visible=False, elem_id="training_indicator")
         
     | 
| 23 | 
         
             
                    with gr.Column(elem_id="main_page_content"):
         
     | 
| 24 | 
         
             
                        with gr.Row():
         
     | 
| 25 | 
         
             
                            gr.Markdown(
         
     | 
| 
         | 
|
| 29 | 
         
             
                                """,
         
     | 
| 30 | 
         
             
                                elem_id="page_title",
         
     | 
| 31 | 
         
             
                            )
         
     | 
| 32 | 
         
            +
                            with gr.Column(
         
     | 
| 33 | 
         
            +
                                elem_id="global_base_model_select_group",
         
     | 
| 34 | 
         
            +
                                elem_classes="disable_while_training without_message"
         
     | 
| 35 | 
         
            +
                            ):
         
     | 
| 36 | 
         
             
                                global_base_model_select = gr.Dropdown(
         
     | 
| 37 | 
         
             
                                    label="Base Model",
         
     | 
| 38 | 
         
             
                                    elem_id="global_base_model_select",
         
     | 
| 
         | 
|
| 104 | 
         
             
                    ]
         
     | 
| 105 | 
         
             
                )
         
     | 
| 106 | 
         | 
| 107 | 
         
            +
                main_page_blocks.load(
         
     | 
| 108 | 
         
            +
                    fn=lambda: gr.HTML.update(
         
     | 
| 109 | 
         
            +
                        visible=Global.is_training or Global.is_train_starting,
         
     | 
| 110 | 
         
            +
                        value=Global.is_training and "training"
         
     | 
| 111 | 
         
            +
                        or (
         
     | 
| 112 | 
         
            +
                            Global.is_train_starting and "train_starting" or ""
         
     | 
| 113 | 
         
            +
                        )
         
     | 
| 114 | 
         
            +
                    ),
         
     | 
| 115 | 
         
            +
                    inputs=None,
         
     | 
| 116 | 
         
            +
                    outputs=[training_indicator],
         
     | 
| 117 | 
         
            +
                    every=2
         
     | 
| 118 | 
         
            +
                )
         
     | 
| 119 | 
         
            +
             
     | 
| 120 | 
         
             
                main_page_blocks.load(_js=f"""
         
     | 
| 121 | 
         
             
                function () {{
         
     | 
| 122 | 
         
             
                    {popperjs_core_code()}
         
     | 
| 
         | 
|
| 257 | 
         
             
                }
         
     | 
| 258 | 
         
             
                */
         
     | 
| 259 | 
         | 
| 260 | 
         
            +
               .hide_wrap > .wrap {
         
     | 
| 261 | 
         
            +
                   border: 0;
         
     | 
| 262 | 
         
            +
                   background: transparent;
         
     | 
| 263 | 
         
            +
                   pointer-events: none;
         
     | 
| 264 | 
         
            +
               }
         
     | 
| 265 | 
         
            +
             
     | 
| 266 | 
         
             
                .error-message, .error-message p {
         
     | 
| 267 | 
         
             
                    color: var(--error-text-color) !important;
         
     | 
| 268 | 
         
             
                }
         
     | 
| 
         | 
|
| 285 | 
         
             
                    max-height: unset;
         
     | 
| 286 | 
         
             
                }
         
     | 
| 287 | 
         | 
| 288 | 
         
            +
                #training_indicator { display: none; }
         
     | 
| 289 | 
         
            +
                #training_indicator:not(.hidden) ~ * .disable_while_training {
         
     | 
| 290 | 
         
            +
                    position: relative !important;
         
     | 
| 291 | 
         
            +
                    pointer-events: none !important;
         
     | 
| 292 | 
         
            +
                }
         
     | 
| 293 | 
         
            +
                #training_indicator:not(.hidden) ~ * .disable_while_training * {
         
     | 
| 294 | 
         
            +
                    pointer-events: none !important;
         
     | 
| 295 | 
         
            +
                }
         
     | 
| 296 | 
         
            +
                #training_indicator:not(.hidden) ~ * .disable_while_training::after {
         
     | 
| 297 | 
         
            +
                    content: "Disabled while training is in progress";
         
     | 
| 298 | 
         
            +
                    display: flex;
         
     | 
| 299 | 
         
            +
                    position: absolute !important;
         
     | 
| 300 | 
         
            +
                    z-index: 70;
         
     | 
| 301 | 
         
            +
                    top: 0;
         
     | 
| 302 | 
         
            +
                    left: 0;
         
     | 
| 303 | 
         
            +
                    right: 0;
         
     | 
| 304 | 
         
            +
                    bottom: 0;
         
     | 
| 305 | 
         
            +
                    background: var(--block-background-fill);
         
     | 
| 306 | 
         
            +
                    opacity: 0.7;
         
     | 
| 307 | 
         
            +
                    justify-content: center;
         
     | 
| 308 | 
         
            +
                    align-items: center;
         
     | 
| 309 | 
         
            +
                    color: var(--body-text-color);
         
     | 
| 310 | 
         
            +
                    font-size: var(--text-lg);
         
     | 
| 311 | 
         
            +
                    font-weight: var(--weight-bold);
         
     | 
| 312 | 
         
            +
                    text-transform: uppercase;
         
     | 
| 313 | 
         
            +
                }
         
     | 
| 314 | 
         
            +
                #training_indicator:not(.hidden) ~ * .disable_while_training.without_message::after {
         
     | 
| 315 | 
         
            +
                    content: "";
         
     | 
| 316 | 
         
            +
                }
         
     | 
| 317 | 
         
            +
             
     | 
| 318 | 
         
             
                #page_title {
         
     | 
| 319 | 
         
             
                    flex-grow: 3;
         
     | 
| 320 | 
         
             
                }
         
     | 
    	
        llama_lora/ui/tokenizer_ui.py
    CHANGED
    
    | 
         @@ -41,7 +41,7 @@ def tokenizer_ui(): 
     | 
|
| 41 | 
         
             
                things_that_might_timeout = []
         
     | 
| 42 | 
         | 
| 43 | 
         
             
                with gr.Blocks() as tokenizer_ui_blocks:
         
     | 
| 44 | 
         
            -
                    with gr.Row():
         
     | 
| 45 | 
         
             
                        with gr.Column():
         
     | 
| 46 | 
         
             
                            encoded_tokens = gr.Code(
         
     | 
| 47 | 
         
             
                                label="Encoded Tokens (JSON)",
         
     | 
| 
         | 
|
| 41 | 
         
             
                things_that_might_timeout = []
         
     | 
| 42 | 
         | 
| 43 | 
         
             
                with gr.Blocks() as tokenizer_ui_blocks:
         
     | 
| 44 | 
         
            +
                    with gr.Row(elem_classes="disable_while_training"):
         
     | 
| 45 | 
         
             
                        with gr.Column():
         
     | 
| 46 | 
         
             
                            encoded_tokens = gr.Code(
         
     | 
| 47 | 
         
             
                                label="Encoded Tokens (JSON)",
         
     | 
    	
        llama_lora/ui/trainer_callback.py
    ADDED
    
    | 
         @@ -0,0 +1,104 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import time
         
     | 
| 2 | 
         
            +
            import traceback
         
     | 
| 3 | 
         
            +
            from transformers import TrainerCallback
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
            from ..globals import Global
         
     | 
| 6 | 
         
            +
            from ..utils.eta_predictor import ETAPredictor
         
     | 
| 7 | 
         
            +
             
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            def reset_training_status():
         
     | 
| 10 | 
         
            +
                Global.is_train_starting = False
         
     | 
| 11 | 
         
            +
                Global.is_training = False
         
     | 
| 12 | 
         
            +
                Global.should_stop_training = False
         
     | 
| 13 | 
         
            +
                Global.train_started_at = time.time()
         
     | 
| 14 | 
         
            +
                Global.training_error_message = None
         
     | 
| 15 | 
         
            +
                Global.training_error_detail = None
         
     | 
| 16 | 
         
            +
                Global.training_total_epochs = 1
         
     | 
| 17 | 
         
            +
                Global.training_current_epoch = 0.0
         
     | 
| 18 | 
         
            +
                Global.training_total_steps = 1
         
     | 
| 19 | 
         
            +
                Global.training_current_step = 0
         
     | 
| 20 | 
         
            +
                Global.training_progress = 0.0
         
     | 
| 21 | 
         
            +
                Global.training_log_history = []
         
     | 
| 22 | 
         
            +
                Global.training_status_text = ""
         
     | 
| 23 | 
         
            +
                Global.training_eta_predictor = ETAPredictor()
         
     | 
| 24 | 
         
            +
                Global.training_eta = None
         
     | 
| 25 | 
         
            +
                Global.train_output = None
         
     | 
| 26 | 
         
            +
                Global.train_output_str = None
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
             
     | 
| 29 | 
         
            +
            def get_progress_text(current_epoch, total_epochs, last_loss):
         
     | 
| 30 | 
         
            +
                progress_detail = f"Epoch {current_epoch:.2f}/{total_epochs}"
         
     | 
| 31 | 
         
            +
                if last_loss is not None:
         
     | 
| 32 | 
         
            +
                    progress_detail += f", Loss: {last_loss:.4f}"
         
     | 
| 33 | 
         
            +
                return f"Training... ({progress_detail})"
         
     | 
| 34 | 
         
            +
             
     | 
| 35 | 
         
            +
             
     | 
| 36 | 
         
            +
            def set_train_output(output):
         
     | 
| 37 | 
         
            +
                end_by = 'aborted' if Global.should_stop_training else 'completed'
         
     | 
| 38 | 
         
            +
                result_message = f"Training {end_by}"
         
     | 
| 39 | 
         
            +
                Global.training_status_text = result_message
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
                Global.train_output = output
         
     | 
| 42 | 
         
            +
                Global.train_output_str = str(output)
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
                return result_message
         
     | 
| 45 | 
         
            +
             
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
            def update_training_states(
         
     | 
| 48 | 
         
            +
                    current_step, total_steps,
         
     | 
| 49 | 
         
            +
                    current_epoch, total_epochs,
         
     | 
| 50 | 
         
            +
                    log_history):
         
     | 
| 51 | 
         
            +
             
     | 
| 52 | 
         
            +
                Global.training_total_steps = total_steps
         
     | 
| 53 | 
         
            +
                Global.training_current_step = current_step
         
     | 
| 54 | 
         
            +
                Global.training_total_epochs = total_epochs
         
     | 
| 55 | 
         
            +
                Global.training_current_epoch = current_epoch
         
     | 
| 56 | 
         
            +
                Global.training_progress = current_step / total_steps
         
     | 
| 57 | 
         
            +
                Global.training_log_history = log_history
         
     | 
| 58 | 
         
            +
                Global.training_eta = Global.training_eta_predictor.predict_eta(current_step, total_steps)
         
     | 
| 59 | 
         
            +
             
     | 
| 60 | 
         
            +
                last_history = None
         
     | 
| 61 | 
         
            +
                last_loss = None
         
     | 
| 62 | 
         
            +
                if len(Global.training_log_history) > 0:
         
     | 
| 63 | 
         
            +
                    last_history = log_history[-1]
         
     | 
| 64 | 
         
            +
                    last_loss = last_history.get('loss', None)
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
                Global.training_status_text = get_progress_text(
         
     | 
| 67 | 
         
            +
                    total_epochs=total_epochs,
         
     | 
| 68 | 
         
            +
                    current_epoch=current_epoch,
         
     | 
| 69 | 
         
            +
                    last_loss=last_loss,
         
     | 
| 70 | 
         
            +
                )
         
     | 
| 71 | 
         
            +
             
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
class UiTrainerCallback(TrainerCallback):
    """Trainer callback that mirrors training progress into Global for the UI
    and honors UI-requested aborts via Global.should_stop_training."""

    def _on_progress(self, args, state, control):
        # A UI "stop" request is translated into the trainer's cooperative
        # stop flag; training ends after the current step.
        if Global.should_stop_training:
            control.should_training_stop = True

        try:
            # NOTE(review): in transformers, TrainerState.max_steps is an int
            # (computed by the Trainer before training starts), not None — so
            # the fallback branch below looks dead. Also, TrainerState has no
            # `steps_per_epoch` attribute, so if the fallback ever ran it
            # would raise AttributeError (swallowed by the except). Confirm
            # against the pinned transformers version.
            total_steps = (
                state.max_steps if state.max_steps is not None
                else state.num_train_epochs * state.steps_per_epoch)
            current_step = state.global_step

            total_epochs = args.num_train_epochs
            current_epoch = state.epoch

            log_history = state.log_history

            update_training_states(
                total_steps=total_steps,
                current_step=current_step,
                total_epochs=total_epochs,
                current_epoch=current_epoch,
                log_history=log_history
            )
        except Exception as e:
            # Progress reporting must never kill the training run; log and
            # keep going.
            print("Error occurred while updating UI status:", e)
            traceback.print_exc()

    def on_epoch_begin(self, args, state, control, **kwargs):
        # Refresh the UI at each epoch boundary.
        self._on_progress(args, state, control)

    def on_step_end(self, args, state, control, **kwargs):
        # Refresh the UI after every optimizer step.
        self._on_progress(args, state, control)
         
     | 
    	
        llama_lora/utils/eta_predictor.py
    ADDED
    
    | 
         @@ -0,0 +1,54 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import time
         
     | 
| 2 | 
         
            +
            import traceback
         
     | 
| 3 | 
         
            +
            from collections import deque
         
     | 
| 4 | 
         
            +
            from typing import Optional
         
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
             
     | 
| 7 | 
         
            +
            class ETAPredictor:
         
     | 
| 8 | 
         
            +
                def __init__(self, lookback_minutes: int = 180):
         
     | 
| 9 | 
         
            +
                    self.lookback_seconds = lookback_minutes * 60  # convert minutes to seconds
         
     | 
| 10 | 
         
            +
                    self.data = deque()
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
                def _cleanup_old_data(self):
         
     | 
| 13 | 
         
            +
                    current_time = time.time()
         
     | 
| 14 | 
         
            +
                    while self.data and current_time - self.data[0][1] > self.lookback_seconds:
         
     | 
| 15 | 
         
            +
                        self.data.popleft()
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
                def predict_eta(
         
     | 
| 18 | 
         
            +
                        self, current_step: int, total_steps: int
         
     | 
| 19 | 
         
            +
                ) -> Optional[int]:
         
     | 
| 20 | 
         
            +
                    try:
         
     | 
| 21 | 
         
            +
                        current_time = time.time()
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
                        # Calculate dynamic log interval based on current logged data
         
     | 
| 24 | 
         
            +
                        log_interval = 1
         
     | 
| 25 | 
         
            +
                        if len(self.data) > 100:
         
     | 
| 26 | 
         
            +
                            log_interval = 10
         
     | 
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
                        # Only log data if last log is at least log_interval seconds ago
         
     | 
| 29 | 
         
            +
                        if len(self.data) < 1 or current_time - self.data[-1][1] >= log_interval:
         
     | 
| 30 | 
         
            +
                            self.data.append((current_step, current_time))
         
     | 
| 31 | 
         
            +
                            self._cleanup_old_data()
         
     | 
| 32 | 
         
            +
             
     | 
| 33 | 
         
            +
                        # Only predict if we have enough data
         
     | 
| 34 | 
         
            +
                        if len(self.data) < 2 or self.data[-1][1] - self.data[0][1] < 5:
         
     | 
| 35 | 
         
            +
                            return None
         
     | 
| 36 | 
         
            +
             
     | 
| 37 | 
         
            +
                        first_step, first_time = self.data[0]
         
     | 
| 38 | 
         
            +
                        steps_completed = current_step - first_step
         
     | 
| 39 | 
         
            +
                        time_elapsed = current_time - first_time
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
                        if steps_completed == 0:
         
     | 
| 42 | 
         
            +
                            return None
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
                        time_per_step = time_elapsed / steps_completed
         
     | 
| 45 | 
         
            +
                        steps_remaining = total_steps - current_step
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                        remaining_seconds = steps_remaining * time_per_step
         
     | 
| 48 | 
         
            +
                        eta_unix_timestamp = current_time + remaining_seconds
         
     | 
| 49 | 
         
            +
             
     | 
| 50 | 
         
            +
                        return int(eta_unix_timestamp)
         
     | 
| 51 | 
         
            +
                    except Exception as e:
         
     | 
| 52 | 
         
            +
                        print("Error predicting ETA:", e)
         
     | 
| 53 | 
         
            +
                        traceback.print_exc()
         
     | 
| 54 | 
         
            +
                        return None
         
     |