Spaces:
Runtime error
Runtime error
zetavg
commited on
make the training process async
Browse files- README.md +2 -2
- app.py +4 -0
- config.yaml.sample +2 -0
- llama_lora/config.py +7 -1
- llama_lora/globals.py +19 -0
- llama_lora/models.py +8 -0
- llama_lora/ui/finetune/finetune_ui.py +23 -12
- llama_lora/ui/finetune/script.js +26 -11
- llama_lora/ui/finetune/style.css +114 -4
- llama_lora/ui/finetune/training.py +275 -212
- llama_lora/ui/inference_ui.py +2 -2
- llama_lora/ui/main_page.py +55 -1
- llama_lora/ui/tokenizer_ui.py +1 -1
- llama_lora/ui/trainer_callback.py +104 -0
- llama_lora/utils/eta_predictor.py +54 -0
README.md
CHANGED
@@ -70,7 +70,7 @@ setup: |
|
|
70 |
# Start the app.
|
71 |
run: |
|
72 |
echo 'Starting...'
|
73 |
-
python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key="$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --base_model=decapoda-research/llama-7b-hf --base_model_choices='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j,databricks/dolly-v2-7b --share
|
74 |
```
|
75 |
|
76 |
Then launch a cluster to run the task:
|
@@ -100,7 +100,7 @@ When you are done, run `sky stop <cluster_name>` to stop the cluster. To termina
|
|
100 |
|
101 |
```bash
|
102 |
pip install -r requirements.lock.txt
|
103 |
-
python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --share
|
104 |
```
|
105 |
|
106 |
You will see the local and public URLs of the app in the terminal. Open the URL in your browser to use the app.
|
|
|
70 |
# Start the app.
|
71 |
run: |
|
72 |
echo 'Starting...'
|
73 |
+
python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key="$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --timezone='Atlantic/Reykjavik' --base_model=decapoda-research/llama-7b-hf --base_model_choices='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j,databricks/dolly-v2-7b --share
|
74 |
```
|
75 |
|
76 |
Then launch a cluster to run the task:
|
|
|
100 |
|
101 |
```bash
|
102 |
pip install -r requirements.lock.txt
|
103 |
+
python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --timezone='Atlantic/Reykjavik' --share
|
104 |
```
|
105 |
|
106 |
You will see the local and public URLs of the app in the terminal. Open the URL in your browser to use the app.
|
app.py
CHANGED
@@ -28,6 +28,7 @@ def main(
|
|
28 |
ui_dev_mode: Union[bool, None] = None,
|
29 |
wandb_api_key: Union[str, None] = None,
|
30 |
wandb_project: Union[str, None] = None,
|
|
|
31 |
):
|
32 |
'''
|
33 |
Start the LLaMA-LoRA Tuner UI.
|
@@ -76,6 +77,9 @@ def main(
|
|
76 |
if wandb_project is not None:
|
77 |
Config.default_wandb_project = wandb_project
|
78 |
|
|
|
|
|
|
|
79 |
if ui_dev_mode is not None:
|
80 |
Config.ui_dev_mode = ui_dev_mode
|
81 |
|
|
|
28 |
ui_dev_mode: Union[bool, None] = None,
|
29 |
wandb_api_key: Union[str, None] = None,
|
30 |
wandb_project: Union[str, None] = None,
|
31 |
+
timezone: Union[str, None] = None,
|
32 |
):
|
33 |
'''
|
34 |
Start the LLaMA-LoRA Tuner UI.
|
|
|
77 |
if wandb_project is not None:
|
78 |
Config.default_wandb_project = wandb_project
|
79 |
|
80 |
+
if timezone is not None:
|
81 |
+
Config.timezone = timezone
|
82 |
+
|
83 |
if ui_dev_mode is not None:
|
84 |
Config.ui_dev_mode = ui_dev_mode
|
85 |
|
config.yaml.sample
CHANGED
@@ -9,6 +9,8 @@ base_model_choices:
|
|
9 |
load_8bit: false
|
10 |
trust_remote_code: false
|
11 |
|
|
|
|
|
12 |
# UI Customization
|
13 |
# ui_title: LLM Tuner
|
14 |
# ui_emoji: 🦙🎛️
|
|
|
9 |
load_8bit: false
|
10 |
trust_remote_code: false
|
11 |
|
12 |
+
# timezone: Atlantic/Reykjavik
|
13 |
+
|
14 |
# UI Customization
|
15 |
# ui_title: LLM Tuner
|
16 |
# ui_emoji: 🦙🎛️
|
llama_lora/config.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import os
|
2 |
-
|
|
|
3 |
|
4 |
|
5 |
class Config:
|
@@ -15,6 +16,8 @@ class Config:
|
|
15 |
|
16 |
trust_remote_code: bool = False
|
17 |
|
|
|
|
|
18 |
# WandB
|
19 |
enable_wandb: Union[bool, None] = False
|
20 |
wandb_api_key: Union[str, None] = None
|
@@ -37,6 +40,9 @@ def process_config():
|
|
37 |
base_model_choices = [name.strip() for name in base_model_choices]
|
38 |
Config.base_model_choices = base_model_choices
|
39 |
|
|
|
|
|
|
|
40 |
if Config.default_base_model_name not in Config.base_model_choices:
|
41 |
Config.base_model_choices = [Config.default_base_model_name] + Config.base_model_choices
|
42 |
|
|
|
1 |
import os
|
2 |
+
import pytz
|
3 |
+
from typing import List, Union, Any
|
4 |
|
5 |
|
6 |
class Config:
|
|
|
16 |
|
17 |
trust_remote_code: bool = False
|
18 |
|
19 |
+
timezone: Any = pytz.UTC
|
20 |
+
|
21 |
# WandB
|
22 |
enable_wandb: Union[bool, None] = False
|
23 |
wandb_api_key: Union[str, None] = None
|
|
|
40 |
base_model_choices = [name.strip() for name in base_model_choices]
|
41 |
Config.base_model_choices = base_model_choices
|
42 |
|
43 |
+
if isinstance(Config.timezone, str):
|
44 |
+
Config.timezone = pytz.timezone(Config.timezone)
|
45 |
+
|
46 |
if Config.default_base_model_name not in Config.base_model_choices:
|
47 |
Config.base_model_choices = [Config.default_base_model_name] + Config.base_model_choices
|
48 |
|
llama_lora/globals.py
CHANGED
@@ -12,6 +12,7 @@ import nvidia_smi
|
|
12 |
from .dynamic_import import dynamic_import
|
13 |
from .config import Config
|
14 |
from .utils.lru_cache import LRUCache
|
|
|
15 |
|
16 |
|
17 |
class Global:
|
@@ -31,6 +32,24 @@ class Global:
|
|
31 |
# Training Control
|
32 |
should_stop_training: bool = False
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
# Generation Control
|
35 |
should_stop_generating: bool = False
|
36 |
generation_force_stopped_at: Union[float, None] = None
|
|
|
12 |
from .dynamic_import import dynamic_import
|
13 |
from .config import Config
|
14 |
from .utils.lru_cache import LRUCache
|
15 |
+
from .utils.eta_predictor import ETAPredictor
|
16 |
|
17 |
|
18 |
class Global:
|
|
|
32 |
# Training Control
|
33 |
should_stop_training: bool = False
|
34 |
|
35 |
+
# Training Status
|
36 |
+
is_train_starting: bool = False
|
37 |
+
is_training: bool = False
|
38 |
+
train_started_at: float = 0.0
|
39 |
+
training_error_message: Union[str, None] = None
|
40 |
+
training_error_detail: Union[str, None] = None
|
41 |
+
training_total_epochs: int = 0
|
42 |
+
training_current_epoch: float = 0.0
|
43 |
+
training_total_steps: int = 0
|
44 |
+
training_current_step: int = 0
|
45 |
+
training_progress: float = 0.0
|
46 |
+
training_log_history: List[Any] = []
|
47 |
+
training_status_text: str = ""
|
48 |
+
training_eta_predictor = ETAPredictor()
|
49 |
+
training_eta: Union[int, None] = None
|
50 |
+
train_output: Union[None, Any] = None
|
51 |
+
train_output_str: Union[None, str] = None
|
52 |
+
|
53 |
# Generation Control
|
54 |
should_stop_generating: bool = False
|
55 |
generation_force_stopped_at: Union[float, None] = None
|
llama_lora/models.py
CHANGED
@@ -26,6 +26,8 @@ def get_peft_model_class():
|
|
26 |
def get_new_base_model(base_model_name):
|
27 |
if Config.ui_dev_mode:
|
28 |
return
|
|
|
|
|
29 |
|
30 |
if Global.new_base_model_that_is_ready_to_be_used:
|
31 |
if Global.name_of_new_base_model_that_is_ready_to_be_used == base_model_name:
|
@@ -121,6 +123,9 @@ def get_tokenizer(base_model_name):
|
|
121 |
if Config.ui_dev_mode:
|
122 |
return
|
123 |
|
|
|
|
|
|
|
124 |
loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
|
125 |
if loaded_tokenizer:
|
126 |
return loaded_tokenizer
|
@@ -150,6 +155,9 @@ def get_model(
|
|
150 |
if Config.ui_dev_mode:
|
151 |
return
|
152 |
|
|
|
|
|
|
|
153 |
if peft_model_name == "None":
|
154 |
peft_model_name = None
|
155 |
|
|
|
26 |
def get_new_base_model(base_model_name):
|
27 |
if Config.ui_dev_mode:
|
28 |
return
|
29 |
+
if Global.is_train_starting or Global.is_training:
|
30 |
+
raise Exception("Cannot load new base model while training.")
|
31 |
|
32 |
if Global.new_base_model_that_is_ready_to_be_used:
|
33 |
if Global.name_of_new_base_model_that_is_ready_to_be_used == base_model_name:
|
|
|
123 |
if Config.ui_dev_mode:
|
124 |
return
|
125 |
|
126 |
+
if Global.is_train_starting or Global.is_training:
|
127 |
+
raise Exception("Cannot load new base model while training.")
|
128 |
+
|
129 |
loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
|
130 |
if loaded_tokenizer:
|
131 |
return loaded_tokenizer
|
|
|
155 |
if Config.ui_dev_mode:
|
156 |
return
|
157 |
|
158 |
+
if Global.is_train_starting or Global.is_training:
|
159 |
+
raise Exception("Cannot load new base model while training.")
|
160 |
+
|
161 |
if peft_model_name == "None":
|
162 |
peft_model_name = None
|
163 |
|
llama_lora/ui/finetune/finetune_ui.py
CHANGED
@@ -27,7 +27,8 @@ from .previewing import (
|
|
27 |
refresh_dataset_items_count,
|
28 |
)
|
29 |
from .training import (
|
30 |
-
do_train
|
|
|
31 |
)
|
32 |
|
33 |
register_css_style('finetune', relative_read_file(__file__, "style.css"))
|
@@ -770,19 +771,22 @@ def finetune_ui():
|
|
770 |
)
|
771 |
)
|
772 |
|
773 |
-
|
774 |
"Training results will be shown here.",
|
775 |
label="Train Output",
|
776 |
elem_id="finetune_training_status")
|
777 |
|
778 |
-
|
|
|
|
|
|
|
779 |
fn=do_train,
|
780 |
inputs=(dataset_inputs + finetune_args + [
|
781 |
model_name,
|
782 |
continue_from_model,
|
783 |
continue_from_checkpoint,
|
784 |
]),
|
785 |
-
outputs=
|
786 |
)
|
787 |
|
788 |
# controlled by JS, shows the confirm_abort_button
|
@@ -790,13 +794,20 @@ def finetune_ui():
|
|
790 |
confirm_abort_button.click(
|
791 |
fn=do_abort_training,
|
792 |
inputs=None, outputs=None,
|
793 |
-
cancels=[
|
794 |
-
|
795 |
-
stop_timeoutable_btn = gr.Button(
|
796 |
-
"stop not-responding elements",
|
797 |
-
elem_id="inference_stop_timeoutable_btn",
|
798 |
-
elem_classes="foot_stop_timeoutable_btn")
|
799 |
-
stop_timeoutable_btn.click(
|
800 |
-
fn=None, inputs=None, outputs=None, cancels=things_that_might_timeout)
|
801 |
|
|
|
|
|
|
|
|
|
|
|
|
|
802 |
finetune_ui_blocks.load(_js=relative_read_file(__file__, "script.js"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
refresh_dataset_items_count,
|
28 |
)
|
29 |
from .training import (
|
30 |
+
do_train,
|
31 |
+
render_training_status
|
32 |
)
|
33 |
|
34 |
register_css_style('finetune', relative_read_file(__file__, "style.css"))
|
|
|
771 |
)
|
772 |
)
|
773 |
|
774 |
+
train_status = gr.HTML(
|
775 |
"Training results will be shown here.",
|
776 |
label="Train Output",
|
777 |
elem_id="finetune_training_status")
|
778 |
|
779 |
+
training_indicator = gr.HTML(
|
780 |
+
"training_indicator", visible=False, elem_id="finetune_training_indicator")
|
781 |
+
|
782 |
+
train_start = train_btn.click(
|
783 |
fn=do_train,
|
784 |
inputs=(dataset_inputs + finetune_args + [
|
785 |
model_name,
|
786 |
continue_from_model,
|
787 |
continue_from_checkpoint,
|
788 |
]),
|
789 |
+
outputs=[train_status, training_indicator]
|
790 |
)
|
791 |
|
792 |
# controlled by JS, shows the confirm_abort_button
|
|
|
794 |
confirm_abort_button.click(
|
795 |
fn=do_abort_training,
|
796 |
inputs=None, outputs=None,
|
797 |
+
cancels=[train_start])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
798 |
|
799 |
+
training_status_updates = finetune_ui_blocks.load(
|
800 |
+
fn=render_training_status,
|
801 |
+
inputs=None,
|
802 |
+
outputs=[train_status, training_indicator],
|
803 |
+
every=0.1
|
804 |
+
)
|
805 |
finetune_ui_blocks.load(_js=relative_read_file(__file__, "script.js"))
|
806 |
+
|
807 |
+
# things_that_might_timeout.append(training_status_updates)
|
808 |
+
stop_timeoutable_btn = gr.Button(
|
809 |
+
"stop not-responding elements",
|
810 |
+
elem_id="inference_stop_timeoutable_btn",
|
811 |
+
elem_classes="foot_stop_timeoutable_btn")
|
812 |
+
stop_timeoutable_btn.click(
|
813 |
+
fn=None, inputs=None, outputs=None, cancels=things_that_might_timeout)
|
llama_lora/ui/finetune/script.js
CHANGED
@@ -130,10 +130,10 @@ function finetune_ui_blocks_js() {
|
|
130 |
|
131 |
// Show/hide start and stop button base on the state.
|
132 |
setTimeout(function () {
|
133 |
-
// Make the '#
|
134 |
-
if (!document.querySelector('#
|
135 |
-
|
136 |
-
}
|
137 |
|
138 |
setTimeout(function () {
|
139 |
let resetStopButtonTimer;
|
@@ -156,11 +156,20 @@ function finetune_ui_blocks_js() {
|
|
156 |
document.getElementById('finetune_confirm_stop_btn').style.display =
|
157 |
'block';
|
158 |
});
|
159 |
-
const
|
160 |
-
|
|
|
|
|
|
|
161 |
);
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
if (resetStopButtonTimer) clearTimeout(resetStopButtonTimer);
|
165 |
document.getElementById('finetune_start_btn').style.display = 'block';
|
166 |
document.getElementById('finetune_stop_btn').style.display = 'none';
|
@@ -173,13 +182,19 @@ function finetune_ui_blocks_js() {
|
|
173 |
'none';
|
174 |
}
|
175 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
new MutationObserver(function (mutationsList, observer) {
|
177 |
-
|
178 |
-
}).observe(
|
179 |
attributes: true,
|
180 |
attributeFilter: ['class'],
|
181 |
});
|
182 |
-
|
183 |
}, 500);
|
184 |
}, 0);
|
185 |
}
|
|
|
130 |
|
131 |
// Show/hide start and stop button base on the state.
|
132 |
setTimeout(function () {
|
133 |
+
// Make the '#finetune_training_indicator > .wrap' element appear
|
134 |
+
// if (!document.querySelector('#finetune_training_indicator > .wrap')) {
|
135 |
+
// document.getElementById('finetune_confirm_stop_btn').click();
|
136 |
+
// }
|
137 |
|
138 |
setTimeout(function () {
|
139 |
let resetStopButtonTimer;
|
|
|
156 |
document.getElementById('finetune_confirm_stop_btn').style.display =
|
157 |
'block';
|
158 |
});
|
159 |
+
// const training_indicator_wrap_element = document.querySelector(
|
160 |
+
// '#finetune_training_indicator > .wrap'
|
161 |
+
// );
|
162 |
+
const training_indicator_element = document.querySelector(
|
163 |
+
'#finetune_training_indicator'
|
164 |
);
|
165 |
+
let isTraining = undefined;
|
166 |
+
function handle_training_indicator_change() {
|
167 |
+
// const wrapperHidden = Array.from(training_indicator_wrap_element.classList).includes('hide');
|
168 |
+
const hidden = Array.from(training_indicator_element.classList).includes('hidden');
|
169 |
+
const newIsTraining = !(/* wrapperHidden && */ hidden);
|
170 |
+
if (newIsTraining === isTraining) return;
|
171 |
+
isTraining = newIsTraining;
|
172 |
+
if (!isTraining) {
|
173 |
if (resetStopButtonTimer) clearTimeout(resetStopButtonTimer);
|
174 |
document.getElementById('finetune_start_btn').style.display = 'block';
|
175 |
document.getElementById('finetune_stop_btn').style.display = 'none';
|
|
|
182 |
'none';
|
183 |
}
|
184 |
}
|
185 |
+
// new MutationObserver(function (mutationsList, observer) {
|
186 |
+
// handle_training_indicator_change();
|
187 |
+
// }).observe(training_indicator_wrap_element, {
|
188 |
+
// attributes: true,
|
189 |
+
// attributeFilter: ['class'],
|
190 |
+
// });
|
191 |
new MutationObserver(function (mutationsList, observer) {
|
192 |
+
handle_training_indicator_change();
|
193 |
+
}).observe(training_indicator_element, {
|
194 |
attributes: true,
|
195 |
attributeFilter: ['class'],
|
196 |
});
|
197 |
+
handle_training_indicator_change();
|
198 |
}, 500);
|
199 |
}, 0);
|
200 |
}
|
llama_lora/ui/finetune/style.css
CHANGED
@@ -255,8 +255,118 @@
|
|
255 |
display: none;
|
256 |
}
|
257 |
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
255 |
display: none;
|
256 |
}
|
257 |
|
258 |
+
#finetune_training_status > .wrap {
|
259 |
+
border: 0;
|
260 |
+
background: transparent;
|
261 |
+
pointer-events: none;
|
262 |
+
top: 0;
|
263 |
+
bottom: 0;
|
264 |
+
left: 0;
|
265 |
+
right: 0;
|
266 |
+
}
|
267 |
+
#finetune_training_status > .wrap .meta-text-center {
|
268 |
+
transform: none !important;
|
269 |
+
}
|
270 |
+
|
271 |
+
#finetune_training_status .progress-block {
|
272 |
+
min-height: 100px;
|
273 |
+
display: flex;
|
274 |
+
justify-content: center;
|
275 |
+
align-items: center;
|
276 |
+
background: var(--panel-background-fill);
|
277 |
+
border-radius: var(--radius-lg);
|
278 |
+
border: var(--block-border-width) solid var(--border-color-primary);
|
279 |
+
padding: var(--block-padding);
|
280 |
+
}
|
281 |
+
#finetune_training_status .progress-block.is_training {
|
282 |
+
min-height: 160px;
|
283 |
+
}
|
284 |
+
#finetune_training_status .progress-block .empty-text {
|
285 |
+
text-transform: uppercase;
|
286 |
+
font-weight: 700;
|
287 |
+
font-size: 120%;
|
288 |
+
opacity: 0.12;
|
289 |
+
}
|
290 |
+
#finetune_training_status .progress-block .meta-text {
|
291 |
+
position: absolute;
|
292 |
+
top: 0;
|
293 |
+
right: 0;
|
294 |
+
z-index: var(--layer-2);
|
295 |
+
padding: var(--size-1) var(--size-2);
|
296 |
+
font-size: var(--text-sm);
|
297 |
+
font-family: var(--font-mono);
|
298 |
+
}
|
299 |
+
#finetune_training_status .progress-block .status {
|
300 |
+
white-space: pre-wrap;
|
301 |
+
}
|
302 |
+
#finetune_training_status .progress-block .progress-level {
|
303 |
+
display: flex;
|
304 |
+
flex-direction: column;
|
305 |
+
align-items: center;
|
306 |
+
z-index: var(--layer-2);
|
307 |
+
width: var(--size-full);
|
308 |
+
}
|
309 |
+
#finetune_training_status .progress-block .progress-level-inner {
|
310 |
+
margin: var(--size-2) auto;
|
311 |
+
color: var(--body-text-color);
|
312 |
+
font-size: var(--text-sm);
|
313 |
+
font-family: var(--font-mono);
|
314 |
+
}
|
315 |
+
#finetune_training_status .progress-block .progress-bar-wrap {
|
316 |
+
border: 1px solid var(--border-color-primary);
|
317 |
+
background: var(--background-fill-primary);
|
318 |
+
width: 55.5%;
|
319 |
+
height: var(--size-4);
|
320 |
+
}
|
321 |
+
#finetune_training_status .progress-block .progress-bar {
|
322 |
+
transform-origin: left;
|
323 |
+
background-color: var(--loader-color);
|
324 |
+
width: var(--size-full);
|
325 |
+
height: var(--size-full);
|
326 |
+
transition: all 150ms ease 0s;
|
327 |
}
|
328 |
+
|
329 |
+
#finetune_training_status .progress-block .output {
|
330 |
+
display: flex;
|
331 |
+
flex-direction: column;
|
332 |
+
justify-content: center;
|
333 |
+
align-items: center;
|
334 |
+
}
|
335 |
+
#finetune_training_status .progress-block .output .title {
|
336 |
+
padding: var(--size-1) var(--size-3);
|
337 |
+
font-weight: var(--weight-bold);
|
338 |
+
font-size: var(--text-lg);
|
339 |
+
line-height: var(--line-xs);
|
340 |
+
}
|
341 |
+
#finetune_training_status .progress-block .output .message {
|
342 |
+
padding: var(--size-1) var(--size-3);
|
343 |
+
color: var(--body-text-color) !important;
|
344 |
+
font-family: var(--font-mono);
|
345 |
+
white-space: pre-wrap;
|
346 |
+
}
|
347 |
+
|
348 |
+
#finetune_training_status .progress-block .error {
|
349 |
+
display: flex;
|
350 |
+
flex-direction: column;
|
351 |
+
justify-content: center;
|
352 |
+
align-items: center;
|
353 |
+
}
|
354 |
+
#finetune_training_status .progress-block .error .title {
|
355 |
+
padding: var(--size-1) var(--size-3);
|
356 |
+
color: var(--color-red-500);
|
357 |
+
font-weight: var(--weight-bold);
|
358 |
+
font-size: var(--text-lg);
|
359 |
+
line-height: var(--line-xs);
|
360 |
+
}
|
361 |
+
#finetune_training_status .progress-block .error .error-message {
|
362 |
+
padding: var(--size-1) var(--size-3);
|
363 |
+
color: var(--body-text-color) !important;
|
364 |
+
font-family: var(--font-mono);
|
365 |
+
white-space: pre-wrap;
|
366 |
+
}
|
367 |
+
#finetune_training_status .progress-block.is_error {
|
368 |
+
/* background: var(--error-background-fill) !important; */
|
369 |
+
border: 1px solid var(--error-border-color) !important;
|
370 |
+
}
|
371 |
+
|
372 |
+
#finetune_training_indicator { display: none; }
|
llama_lora/ui/finetune/training.py
CHANGED
@@ -1,24 +1,26 @@
|
|
1 |
import os
|
2 |
import json
|
3 |
import time
|
|
|
|
|
|
|
|
|
|
|
4 |
import gradio as gr
|
5 |
-
import math
|
6 |
|
7 |
-
from transformers import TrainerCallback
|
8 |
from huggingface_hub import try_to_load_from_cache, snapshot_download
|
9 |
|
10 |
from ...config import Config
|
11 |
from ...globals import Global
|
12 |
from ...models import clear_cache, unload_models
|
13 |
from ...utils.prompter import Prompter
|
|
|
|
|
|
|
|
|
14 |
|
15 |
from .data_processing import get_data_from_input
|
16 |
|
17 |
-
should_training_progress_track_tqdm = True
|
18 |
-
|
19 |
-
if Global.gpu_total_cores is not None and Global.gpu_total_cores > 2560:
|
20 |
-
should_training_progress_track_tqdm = False
|
21 |
-
|
22 |
|
23 |
def do_train(
|
24 |
# Dataset
|
@@ -55,8 +57,14 @@ def do_train(
|
|
55 |
model_name,
|
56 |
continue_from_model,
|
57 |
continue_from_checkpoint,
|
58 |
-
progress=gr.Progress(track_tqdm=
|
59 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
try:
|
61 |
base_model_name = Global.base_model_name
|
62 |
tokenizer_name = Global.tokenizer_name or Global.base_model_name
|
@@ -115,18 +123,47 @@ def do_train(
|
|
115 |
raise ValueError(
|
116 |
f"The output directory already exists and is not empty. ({output_dir})")
|
117 |
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
prompter = Prompter(template)
|
128 |
-
# variable_names = prompter.get_variable_names()
|
129 |
-
|
130 |
data = get_data_from_input(
|
131 |
load_dataset_from=load_dataset_from,
|
132 |
dataset_text=dataset_text,
|
@@ -138,208 +175,234 @@ def do_train(
|
|
138 |
prompter=prompter
|
139 |
)
|
140 |
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
|
|
|
|
|
|
184 |
return
|
185 |
-
epochs = 3
|
186 |
-
epoch = i / 100
|
187 |
-
last_loss = None
|
188 |
-
if (i > 20):
|
189 |
-
last_loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
|
190 |
-
|
191 |
-
progress(
|
192 |
-
(i, 300),
|
193 |
-
desc="(Simulate) " +
|
194 |
-
get_progress_text(epoch, epochs, last_loss)
|
195 |
-
)
|
196 |
-
|
197 |
-
time.sleep(0.1)
|
198 |
-
|
199 |
-
time.sleep(2)
|
200 |
-
return message
|
201 |
-
|
202 |
-
if not should_training_progress_track_tqdm:
|
203 |
-
progress(
|
204 |
-
0, desc=f"Preparing model {base_model_name} for training...")
|
205 |
-
|
206 |
-
log_history = []
|
207 |
-
|
208 |
-
class UiTrainerCallback(TrainerCallback):
|
209 |
-
def _on_progress(self, args, state, control):
|
210 |
-
nonlocal log_history
|
211 |
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
)
|
231 |
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
def on_step_end(self, args, state, control, **kwargs):
|
236 |
-
self._on_progress(args, state, control)
|
237 |
-
|
238 |
-
training_callbacks = [UiTrainerCallback]
|
239 |
-
|
240 |
-
Global.should_stop_training = False
|
241 |
-
|
242 |
-
# Do not let other tqdm iterations interfere the progress reporting after training starts.
|
243 |
-
# progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
|
244 |
|
245 |
-
|
246 |
-
os.makedirs(output_dir)
|
247 |
|
248 |
-
|
249 |
-
dataset_name = "N/A (from text input)"
|
250 |
-
if load_dataset_from == "Data Dir":
|
251 |
-
dataset_name = dataset_from_data_dir
|
252 |
|
253 |
-
|
254 |
-
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
'timestamp': time.time(),
|
259 |
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
# 'micro_batch_size': micro_batch_size,
|
265 |
-
# 'gradient_accumulation_steps': gradient_accumulation_steps,
|
266 |
-
# 'epochs': epochs,
|
267 |
-
# 'learning_rate': learning_rate,
|
268 |
-
|
269 |
-
# 'evaluate_data_count': evaluate_data_count,
|
270 |
-
|
271 |
-
# 'lora_r': lora_r,
|
272 |
-
# 'lora_alpha': lora_alpha,
|
273 |
-
# 'lora_dropout': lora_dropout,
|
274 |
-
# 'lora_target_modules': lora_target_modules,
|
275 |
-
}
|
276 |
-
if continue_from_model:
|
277 |
-
info['continued_from_model'] = continue_from_model
|
278 |
-
if continue_from_checkpoint:
|
279 |
-
info['continued_from_checkpoint'] = continue_from_checkpoint
|
280 |
-
|
281 |
-
if Global.version:
|
282 |
-
info['tuner_version'] = Global.version
|
283 |
-
|
284 |
-
json.dump(info, info_json_file, indent=2)
|
285 |
-
|
286 |
-
if not should_training_progress_track_tqdm:
|
287 |
-
progress(0, desc="Train starting...")
|
288 |
-
|
289 |
-
wandb_group = template
|
290 |
-
wandb_tags = [f"template:{template}"]
|
291 |
-
if load_dataset_from == "Data Dir" and dataset_from_data_dir:
|
292 |
-
wandb_group += f"/{dataset_from_data_dir}"
|
293 |
-
wandb_tags.append(f"dataset:{dataset_from_data_dir}")
|
294 |
-
|
295 |
-
train_output = Global.finetune_train_fn(
|
296 |
-
base_model=base_model_name,
|
297 |
-
tokenizer=tokenizer_name,
|
298 |
-
output_dir=output_dir,
|
299 |
-
train_data=train_data,
|
300 |
-
# 128, # batch_size (is not used, use gradient_accumulation_steps instead)
|
301 |
-
micro_batch_size=micro_batch_size,
|
302 |
-
gradient_accumulation_steps=gradient_accumulation_steps,
|
303 |
-
num_train_epochs=epochs,
|
304 |
-
learning_rate=learning_rate,
|
305 |
-
cutoff_len=max_seq_length,
|
306 |
-
val_set_size=evaluate_data_count,
|
307 |
-
lora_r=lora_r,
|
308 |
-
lora_alpha=lora_alpha,
|
309 |
-
lora_dropout=lora_dropout,
|
310 |
-
lora_target_modules=lora_target_modules,
|
311 |
-
lora_modules_to_save=lora_modules_to_save,
|
312 |
-
train_on_inputs=train_on_inputs,
|
313 |
-
load_in_8bit=load_in_8bit,
|
314 |
-
fp16=fp16,
|
315 |
-
bf16=bf16,
|
316 |
-
gradient_checkpointing=gradient_checkpointing,
|
317 |
-
group_by_length=False,
|
318 |
-
resume_from_checkpoint=resume_from_checkpoint_param,
|
319 |
-
save_steps=save_steps,
|
320 |
-
save_total_limit=save_total_limit,
|
321 |
-
logging_steps=logging_steps,
|
322 |
-
additional_training_arguments=additional_training_arguments,
|
323 |
-
additional_lora_config=additional_lora_config,
|
324 |
-
callbacks=training_callbacks,
|
325 |
-
wandb_api_key=Config.wandb_api_key,
|
326 |
-
wandb_project=Config.default_wandb_project if Config.enable_wandb else None,
|
327 |
-
wandb_group=wandb_group,
|
328 |
-
wandb_run_name=model_name,
|
329 |
-
wandb_tags=wandb_tags
|
330 |
-
)
|
331 |
-
|
332 |
-
logs_str = "\n".join([json.dumps(log)
|
333 |
-
for log in log_history]) or "None"
|
334 |
-
|
335 |
-
result_message = f"Training ended:\n{str(train_output)}"
|
336 |
-
print(result_message)
|
337 |
-
# result_message += f"\n\nLogs:\n{logs_str}"
|
338 |
-
|
339 |
-
clear_cache()
|
340 |
-
|
341 |
-
return result_message
|
342 |
|
343 |
except Exception as e:
|
344 |
-
|
345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import os
|
2 |
import json
|
3 |
import time
|
4 |
+
import datetime
|
5 |
+
import pytz
|
6 |
+
import socket
|
7 |
+
import threading
|
8 |
+
import traceback
|
9 |
import gradio as gr
|
|
|
10 |
|
|
|
11 |
from huggingface_hub import try_to_load_from_cache, snapshot_download
|
12 |
|
13 |
from ...config import Config
|
14 |
from ...globals import Global
|
15 |
from ...models import clear_cache, unload_models
|
16 |
from ...utils.prompter import Prompter
|
17 |
+
from ..trainer_callback import (
|
18 |
+
UiTrainerCallback, reset_training_status,
|
19 |
+
update_training_states, set_train_output
|
20 |
+
)
|
21 |
|
22 |
from .data_processing import get_data_from_input
|
23 |
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def do_train(
|
26 |
# Dataset
|
|
|
57 |
model_name,
|
58 |
continue_from_model,
|
59 |
continue_from_checkpoint,
|
60 |
+
progress=gr.Progress(track_tqdm=False),
|
61 |
):
|
62 |
+
if Global.is_training:
|
63 |
+
return render_training_status()
|
64 |
+
|
65 |
+
reset_training_status()
|
66 |
+
Global.is_train_starting = True
|
67 |
+
|
68 |
try:
|
69 |
base_model_name = Global.base_model_name
|
70 |
tokenizer_name = Global.tokenizer_name or Global.base_model_name
|
|
|
123 |
raise ValueError(
|
124 |
f"The output directory already exists and is not empty. ({output_dir})")
|
125 |
|
126 |
+
wandb_group = template
|
127 |
+
wandb_tags = [f"template:{template}"]
|
128 |
+
if load_dataset_from == "Data Dir" and dataset_from_data_dir:
|
129 |
+
wandb_group += f"/{dataset_from_data_dir}"
|
130 |
+
wandb_tags.append(f"dataset:{dataset_from_data_dir}")
|
131 |
|
132 |
+
finetune_args = {
|
133 |
+
'base_model': base_model_name,
|
134 |
+
'tokenizer': tokenizer_name,
|
135 |
+
'output_dir': output_dir,
|
136 |
+
'micro_batch_size': micro_batch_size,
|
137 |
+
'gradient_accumulation_steps': gradient_accumulation_steps,
|
138 |
+
'num_train_epochs': epochs,
|
139 |
+
'learning_rate': learning_rate,
|
140 |
+
'cutoff_len': max_seq_length,
|
141 |
+
'val_set_size': evaluate_data_count,
|
142 |
+
'lora_r': lora_r,
|
143 |
+
'lora_alpha': lora_alpha,
|
144 |
+
'lora_dropout': lora_dropout,
|
145 |
+
'lora_target_modules': lora_target_modules,
|
146 |
+
'lora_modules_to_save': lora_modules_to_save,
|
147 |
+
'train_on_inputs': train_on_inputs,
|
148 |
+
'load_in_8bit': load_in_8bit,
|
149 |
+
'fp16': fp16,
|
150 |
+
'bf16': bf16,
|
151 |
+
'gradient_checkpointing': gradient_checkpointing,
|
152 |
+
'group_by_length': False,
|
153 |
+
'resume_from_checkpoint': resume_from_checkpoint_param,
|
154 |
+
'save_steps': save_steps,
|
155 |
+
'save_total_limit': save_total_limit,
|
156 |
+
'logging_steps': logging_steps,
|
157 |
+
'additional_training_arguments': additional_training_arguments,
|
158 |
+
'additional_lora_config': additional_lora_config,
|
159 |
+
'wandb_api_key': Config.wandb_api_key,
|
160 |
+
'wandb_project': Config.default_wandb_project if Config.enable_wandb else None,
|
161 |
+
'wandb_group': wandb_group,
|
162 |
+
'wandb_run_name': model_name,
|
163 |
+
'wandb_tags': wandb_tags
|
164 |
+
}
|
165 |
|
166 |
prompter = Prompter(template)
|
|
|
|
|
167 |
data = get_data_from_input(
|
168 |
load_dataset_from=load_dataset_from,
|
169 |
dataset_text=dataset_text,
|
|
|
175 |
prompter=prompter
|
176 |
)
|
177 |
|
178 |
+
def training():
|
179 |
+
Global.is_training = True
|
180 |
+
|
181 |
+
try:
|
182 |
+
# Need RAM for training
|
183 |
+
unload_models()
|
184 |
+
Global.new_base_model_that_is_ready_to_be_used = None
|
185 |
+
Global.name_of_new_base_model_that_is_ready_to_be_used = None
|
186 |
+
clear_cache()
|
187 |
+
|
188 |
+
train_data = prompter.get_train_data_from_dataset(data)
|
189 |
+
|
190 |
+
if Config.ui_dev_mode:
|
191 |
+
message = "Currently in UI dev mode, not doing the actual training."
|
192 |
+
message += f"\n\nArgs: {json.dumps(finetune_args, indent=2)}"
|
193 |
+
message += f"\n\nTrain data (first 5):\n{json.dumps(train_data[:5], indent=2)}"
|
194 |
+
|
195 |
+
print(message)
|
196 |
+
|
197 |
+
total_steps = 300
|
198 |
+
for i in range(300):
|
199 |
+
if (Global.should_stop_training):
|
200 |
+
break
|
201 |
+
|
202 |
+
current_step = i + 1
|
203 |
+
total_epochs = 3
|
204 |
+
current_epoch = i / 100
|
205 |
+
log_history = []
|
206 |
+
|
207 |
+
if (i > 20):
|
208 |
+
loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
|
209 |
+
log_history = [{'loss': loss}]
|
210 |
+
|
211 |
+
update_training_states(
|
212 |
+
total_steps=total_steps,
|
213 |
+
current_step=current_step,
|
214 |
+
total_epochs=total_epochs,
|
215 |
+
current_epoch=current_epoch,
|
216 |
+
log_history=log_history
|
217 |
+
)
|
218 |
+
time.sleep(0.1)
|
219 |
+
|
220 |
+
result_message = set_train_output(message)
|
221 |
+
print(result_message)
|
222 |
+
time.sleep(1)
|
223 |
+
Global.is_training = False
|
224 |
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
225 |
|
226 |
+
training_callbacks = [UiTrainerCallback]
|
227 |
+
|
228 |
+
if not os.path.exists(output_dir):
|
229 |
+
os.makedirs(output_dir)
|
230 |
+
|
231 |
+
with open(os.path.join(output_dir, "info.json"), 'w') as info_json_file:
|
232 |
+
dataset_name = "N/A (from text input)"
|
233 |
+
if load_dataset_from == "Data Dir":
|
234 |
+
dataset_name = dataset_from_data_dir
|
235 |
+
|
236 |
+
info = {
|
237 |
+
'base_model': base_model_name,
|
238 |
+
'prompt_template': template,
|
239 |
+
'dataset_name': dataset_name,
|
240 |
+
'dataset_rows': len(train_data),
|
241 |
+
'trained_on_machine': socket.gethostname(),
|
242 |
+
'timestamp': time.time(),
|
243 |
+
}
|
244 |
+
if continue_from_model:
|
245 |
+
info['continued_from_model'] = continue_from_model
|
246 |
+
if continue_from_checkpoint:
|
247 |
+
info['continued_from_checkpoint'] = continue_from_checkpoint
|
248 |
+
|
249 |
+
if Global.version:
|
250 |
+
info['tuner_version'] = Global.version
|
251 |
+
|
252 |
+
json.dump(info, info_json_file, indent=2)
|
253 |
+
|
254 |
+
train_output = Global.finetune_train_fn(
|
255 |
+
train_data=train_data,
|
256 |
+
callbacks=training_callbacks,
|
257 |
+
**finetune_args,
|
258 |
)
|
259 |
|
260 |
+
result_message = set_train_output(train_output)
|
261 |
+
print(result_message + "\n" + str(train_output))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
+
clear_cache()
|
|
|
264 |
|
265 |
+
Global.is_training = False
|
|
|
|
|
|
|
266 |
|
267 |
+
except Exception as e:
|
268 |
+
traceback.print_exc()
|
269 |
+
Global.training_error_message = str(e)
|
270 |
+
finally:
|
271 |
+
Global.is_training = False
|
|
|
272 |
|
273 |
+
training_thread = threading.Thread(target=training)
|
274 |
+
training_thread.daemon = True
|
275 |
+
training_thread.start()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
276 |
|
277 |
except Exception as e:
|
278 |
+
Global.is_training = False
|
279 |
+
traceback.print_exc()
|
280 |
+
Global.training_error_message = str(e)
|
281 |
+
finally:
|
282 |
+
Global.is_train_starting = False
|
283 |
+
|
284 |
+
return render_training_status()
|
285 |
+
|
286 |
+
|
287 |
+
def render_training_status():
|
288 |
+
if not Global.is_training:
|
289 |
+
if Global.is_train_starting:
|
290 |
+
html_content = """
|
291 |
+
<div class="progress-block">
|
292 |
+
<div class="progress-level">
|
293 |
+
<div class="progress-level-inner">
|
294 |
+
Starting...
|
295 |
+
</div>
|
296 |
+
</div>
|
297 |
+
</div>
|
298 |
+
"""
|
299 |
+
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
|
300 |
+
|
301 |
+
if Global.training_error_message:
|
302 |
+
html_content = f"""
|
303 |
+
<div class="progress-block is_error">
|
304 |
+
<div class="progress-level">
|
305 |
+
<div class="error">
|
306 |
+
<div class="title">
|
307 |
+
⚠ Something went wrong
|
308 |
+
</div>
|
309 |
+
<div class="error-message">{Global.training_error_message}</div>
|
310 |
+
</div>
|
311 |
+
</div>
|
312 |
+
</div>
|
313 |
+
"""
|
314 |
+
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
|
315 |
+
|
316 |
+
if Global.train_output_str:
|
317 |
+
end_message = "✅ Training completed"
|
318 |
+
if Global.should_stop_training:
|
319 |
+
end_message = "🛑 Train aborted"
|
320 |
+
html_content = f"""
|
321 |
+
<div class="progress-block">
|
322 |
+
<div class="progress-level">
|
323 |
+
<div class="output">
|
324 |
+
<div class="title">
|
325 |
+
{end_message}
|
326 |
+
</div>
|
327 |
+
<div class="message">{Global.train_output_str}</div>
|
328 |
+
</div>
|
329 |
+
</div>
|
330 |
+
</div>
|
331 |
+
"""
|
332 |
+
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
|
333 |
+
|
334 |
+
if Global.training_status_text:
|
335 |
+
html_content = f"""
|
336 |
+
<div class="progress-block">
|
337 |
+
<div class="status">{Global.training_status_text}</div>
|
338 |
+
</div>
|
339 |
+
"""
|
340 |
+
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
|
341 |
+
|
342 |
+
html_content = """
|
343 |
+
<div class="progress-block">
|
344 |
+
<div class="empty-text">
|
345 |
+
Training status will be shown here
|
346 |
+
</div>
|
347 |
+
</div>
|
348 |
+
"""
|
349 |
+
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
|
350 |
+
|
351 |
+
meta_info = []
|
352 |
+
meta_info.append(
|
353 |
+
f"{Global.training_current_step}/{Global.training_total_steps} steps")
|
354 |
+
current_time = time.time()
|
355 |
+
time_elapsed = current_time - Global.train_started_at
|
356 |
+
time_remaining = -1
|
357 |
+
if Global.training_eta:
|
358 |
+
time_remaining = Global.training_eta - current_time
|
359 |
+
if time_remaining >= 0:
|
360 |
+
meta_info.append(
|
361 |
+
f"{format_time(time_elapsed)}<{format_time(time_remaining)}")
|
362 |
+
meta_info.append(f"ETA: {format_timestamp(Global.training_eta)}")
|
363 |
+
else:
|
364 |
+
meta_info.append(format_time(time_elapsed))
|
365 |
+
|
366 |
+
html_content = f"""
|
367 |
+
<div class="progress-block is_training">
|
368 |
+
<div class="meta-text">{' | '.join(meta_info)}</div>
|
369 |
+
<div class="progress-level">
|
370 |
+
<div class="progress-level-inner">
|
371 |
+
{Global.training_status_text} - {Global.training_progress * 100:.2f}%
|
372 |
+
</div>
|
373 |
+
<div class="progress-bar-wrap">
|
374 |
+
<div class="progress-bar" style="width: {Global.training_progress * 100:.2f}%;">
|
375 |
+
</div>
|
376 |
+
</div>
|
377 |
+
</div>
|
378 |
+
</div>
|
379 |
+
"""
|
380 |
+
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
|
381 |
+
|
382 |
+
|
383 |
+
def format_time(seconds):
|
384 |
+
hours, remainder = divmod(seconds, 3600)
|
385 |
+
minutes, seconds = divmod(remainder, 60)
|
386 |
+
if hours == 0:
|
387 |
+
return "{:02d}:{:02d}".format(int(minutes), int(seconds))
|
388 |
+
else:
|
389 |
+
return "{:02d}:{:02d}:{:02d}".format(int(hours), int(minutes), int(seconds))
|
390 |
+
|
391 |
+
|
392 |
+
def format_timestamp(timestamp):
|
393 |
+
dt_naive = datetime.datetime.utcfromtimestamp(timestamp)
|
394 |
+
utc = pytz.UTC
|
395 |
+
timezone = Config.timezone
|
396 |
+
dt_aware = utc.localize(dt_naive).astimezone(timezone)
|
397 |
+
now = datetime.datetime.now(timezone)
|
398 |
+
delta = dt_aware.date() - now.date()
|
399 |
+
if delta.days == 0:
|
400 |
+
time_str = ""
|
401 |
+
elif delta.days == 1:
|
402 |
+
time_str = "tomorrow at "
|
403 |
+
elif delta.days == -1:
|
404 |
+
time_str = "yesterday at "
|
405 |
+
else:
|
406 |
+
time_str = dt_aware.strftime('%A, %B %d at ')
|
407 |
+
time_str += dt_aware.strftime('%I:%M %p').lower()
|
408 |
+
return time_str
|
llama_lora/ui/inference_ui.py
CHANGED
@@ -381,7 +381,7 @@ def inference_ui():
|
|
381 |
things_that_might_timeout = []
|
382 |
|
383 |
with gr.Blocks() as inference_ui_blocks:
|
384 |
-
with gr.Row():
|
385 |
with gr.Column(elem_id="inference_lora_model_group"):
|
386 |
model_prompt_template_message = gr.Markdown(
|
387 |
"", visible=False, elem_id="inference_lora_model_prompt_template_message")
|
@@ -402,7 +402,7 @@ def inference_ui():
|
|
402 |
reload_selections_button.style(
|
403 |
full_width=False,
|
404 |
size="sm")
|
405 |
-
with gr.Row():
|
406 |
with gr.Column():
|
407 |
with gr.Column(elem_id="inference_prompt_box"):
|
408 |
variable_0 = gr.Textbox(
|
|
|
381 |
things_that_might_timeout = []
|
382 |
|
383 |
with gr.Blocks() as inference_ui_blocks:
|
384 |
+
with gr.Row(elem_classes="disable_while_training"):
|
385 |
with gr.Column(elem_id="inference_lora_model_group"):
|
386 |
model_prompt_template_message = gr.Markdown(
|
387 |
"", visible=False, elem_id="inference_lora_model_prompt_template_message")
|
|
|
402 |
reload_selections_button.style(
|
403 |
full_width=False,
|
404 |
size="sm")
|
405 |
+
with gr.Row(elem_classes="disable_while_training"):
|
406 |
with gr.Column():
|
407 |
with gr.Column(elem_id="inference_prompt_box"):
|
408 |
variable_0 = gr.Textbox(
|
llama_lora/ui/main_page.py
CHANGED
@@ -18,6 +18,8 @@ def main_page():
|
|
18 |
title=title,
|
19 |
css=get_css_styles(),
|
20 |
) as main_page_blocks:
|
|
|
|
|
21 |
with gr.Column(elem_id="main_page_content"):
|
22 |
with gr.Row():
|
23 |
gr.Markdown(
|
@@ -27,7 +29,10 @@ def main_page():
|
|
27 |
""",
|
28 |
elem_id="page_title",
|
29 |
)
|
30 |
-
with gr.Column(
|
|
|
|
|
|
|
31 |
global_base_model_select = gr.Dropdown(
|
32 |
label="Base Model",
|
33 |
elem_id="global_base_model_select",
|
@@ -99,6 +104,19 @@ def main_page():
|
|
99 |
]
|
100 |
)
|
101 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
102 |
main_page_blocks.load(_js=f"""
|
103 |
function () {{
|
104 |
{popperjs_core_code()}
|
@@ -239,6 +257,12 @@ def main_page_custom_css():
|
|
239 |
}
|
240 |
*/
|
241 |
|
|
|
|
|
|
|
|
|
|
|
|
|
242 |
.error-message, .error-message p {
|
243 |
color: var(--error-text-color) !important;
|
244 |
}
|
@@ -261,6 +285,36 @@ def main_page_custom_css():
|
|
261 |
max-height: unset;
|
262 |
}
|
263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
264 |
#page_title {
|
265 |
flex-grow: 3;
|
266 |
}
|
|
|
18 |
title=title,
|
19 |
css=get_css_styles(),
|
20 |
) as main_page_blocks:
|
21 |
+
training_indicator = gr.HTML(
|
22 |
+
"", visible=False, elem_id="training_indicator")
|
23 |
with gr.Column(elem_id="main_page_content"):
|
24 |
with gr.Row():
|
25 |
gr.Markdown(
|
|
|
29 |
""",
|
30 |
elem_id="page_title",
|
31 |
)
|
32 |
+
with gr.Column(
|
33 |
+
elem_id="global_base_model_select_group",
|
34 |
+
elem_classes="disable_while_training without_message"
|
35 |
+
):
|
36 |
global_base_model_select = gr.Dropdown(
|
37 |
label="Base Model",
|
38 |
elem_id="global_base_model_select",
|
|
|
104 |
]
|
105 |
)
|
106 |
|
107 |
+
main_page_blocks.load(
|
108 |
+
fn=lambda: gr.HTML.update(
|
109 |
+
visible=Global.is_training or Global.is_train_starting,
|
110 |
+
value=Global.is_training and "training"
|
111 |
+
or (
|
112 |
+
Global.is_train_starting and "train_starting" or ""
|
113 |
+
)
|
114 |
+
),
|
115 |
+
inputs=None,
|
116 |
+
outputs=[training_indicator],
|
117 |
+
every=2
|
118 |
+
)
|
119 |
+
|
120 |
main_page_blocks.load(_js=f"""
|
121 |
function () {{
|
122 |
{popperjs_core_code()}
|
|
|
257 |
}
|
258 |
*/
|
259 |
|
260 |
+
.hide_wrap > .wrap {
|
261 |
+
border: 0;
|
262 |
+
background: transparent;
|
263 |
+
pointer-events: none;
|
264 |
+
}
|
265 |
+
|
266 |
.error-message, .error-message p {
|
267 |
color: var(--error-text-color) !important;
|
268 |
}
|
|
|
285 |
max-height: unset;
|
286 |
}
|
287 |
|
288 |
+
#training_indicator { display: none; }
|
289 |
+
#training_indicator:not(.hidden) ~ * .disable_while_training {
|
290 |
+
position: relative !important;
|
291 |
+
pointer-events: none !important;
|
292 |
+
}
|
293 |
+
#training_indicator:not(.hidden) ~ * .disable_while_training * {
|
294 |
+
pointer-events: none !important;
|
295 |
+
}
|
296 |
+
#training_indicator:not(.hidden) ~ * .disable_while_training::after {
|
297 |
+
content: "Disabled while training is in progress";
|
298 |
+
display: flex;
|
299 |
+
position: absolute !important;
|
300 |
+
z-index: 70;
|
301 |
+
top: 0;
|
302 |
+
left: 0;
|
303 |
+
right: 0;
|
304 |
+
bottom: 0;
|
305 |
+
background: var(--block-background-fill);
|
306 |
+
opacity: 0.7;
|
307 |
+
justify-content: center;
|
308 |
+
align-items: center;
|
309 |
+
color: var(--body-text-color);
|
310 |
+
font-size: var(--text-lg);
|
311 |
+
font-weight: var(--weight-bold);
|
312 |
+
text-transform: uppercase;
|
313 |
+
}
|
314 |
+
#training_indicator:not(.hidden) ~ * .disable_while_training.without_message::after {
|
315 |
+
content: "";
|
316 |
+
}
|
317 |
+
|
318 |
#page_title {
|
319 |
flex-grow: 3;
|
320 |
}
|
llama_lora/ui/tokenizer_ui.py
CHANGED
@@ -41,7 +41,7 @@ def tokenizer_ui():
|
|
41 |
things_that_might_timeout = []
|
42 |
|
43 |
with gr.Blocks() as tokenizer_ui_blocks:
|
44 |
-
with gr.Row():
|
45 |
with gr.Column():
|
46 |
encoded_tokens = gr.Code(
|
47 |
label="Encoded Tokens (JSON)",
|
|
|
41 |
things_that_might_timeout = []
|
42 |
|
43 |
with gr.Blocks() as tokenizer_ui_blocks:
|
44 |
+
with gr.Row(elem_classes="disable_while_training"):
|
45 |
with gr.Column():
|
46 |
encoded_tokens = gr.Code(
|
47 |
label="Encoded Tokens (JSON)",
|
llama_lora/ui/trainer_callback.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import traceback
|
3 |
+
from transformers import TrainerCallback
|
4 |
+
|
5 |
+
from ..globals import Global
|
6 |
+
from ..utils.eta_predictor import ETAPredictor
|
7 |
+
|
8 |
+
|
9 |
+
def reset_training_status():
|
10 |
+
Global.is_train_starting = False
|
11 |
+
Global.is_training = False
|
12 |
+
Global.should_stop_training = False
|
13 |
+
Global.train_started_at = time.time()
|
14 |
+
Global.training_error_message = None
|
15 |
+
Global.training_error_detail = None
|
16 |
+
Global.training_total_epochs = 1
|
17 |
+
Global.training_current_epoch = 0.0
|
18 |
+
Global.training_total_steps = 1
|
19 |
+
Global.training_current_step = 0
|
20 |
+
Global.training_progress = 0.0
|
21 |
+
Global.training_log_history = []
|
22 |
+
Global.training_status_text = ""
|
23 |
+
Global.training_eta_predictor = ETAPredictor()
|
24 |
+
Global.training_eta = None
|
25 |
+
Global.train_output = None
|
26 |
+
Global.train_output_str = None
|
27 |
+
|
28 |
+
|
29 |
+
def get_progress_text(current_epoch, total_epochs, last_loss):
|
30 |
+
progress_detail = f"Epoch {current_epoch:.2f}/{total_epochs}"
|
31 |
+
if last_loss is not None:
|
32 |
+
progress_detail += f", Loss: {last_loss:.4f}"
|
33 |
+
return f"Training... ({progress_detail})"
|
34 |
+
|
35 |
+
|
36 |
+
def set_train_output(output):
|
37 |
+
end_by = 'aborted' if Global.should_stop_training else 'completed'
|
38 |
+
result_message = f"Training {end_by}"
|
39 |
+
Global.training_status_text = result_message
|
40 |
+
|
41 |
+
Global.train_output = output
|
42 |
+
Global.train_output_str = str(output)
|
43 |
+
|
44 |
+
return result_message
|
45 |
+
|
46 |
+
|
47 |
+
def update_training_states(
|
48 |
+
current_step, total_steps,
|
49 |
+
current_epoch, total_epochs,
|
50 |
+
log_history):
|
51 |
+
|
52 |
+
Global.training_total_steps = total_steps
|
53 |
+
Global.training_current_step = current_step
|
54 |
+
Global.training_total_epochs = total_epochs
|
55 |
+
Global.training_current_epoch = current_epoch
|
56 |
+
Global.training_progress = current_step / total_steps
|
57 |
+
Global.training_log_history = log_history
|
58 |
+
Global.training_eta = Global.training_eta_predictor.predict_eta(current_step, total_steps)
|
59 |
+
|
60 |
+
last_history = None
|
61 |
+
last_loss = None
|
62 |
+
if len(Global.training_log_history) > 0:
|
63 |
+
last_history = log_history[-1]
|
64 |
+
last_loss = last_history.get('loss', None)
|
65 |
+
|
66 |
+
Global.training_status_text = get_progress_text(
|
67 |
+
total_epochs=total_epochs,
|
68 |
+
current_epoch=current_epoch,
|
69 |
+
last_loss=last_loss,
|
70 |
+
)
|
71 |
+
|
72 |
+
|
73 |
+
class UiTrainerCallback(TrainerCallback):
|
74 |
+
def _on_progress(self, args, state, control):
|
75 |
+
if Global.should_stop_training:
|
76 |
+
control.should_training_stop = True
|
77 |
+
|
78 |
+
try:
|
79 |
+
total_steps = (
|
80 |
+
state.max_steps if state.max_steps is not None
|
81 |
+
else state.num_train_epochs * state.steps_per_epoch)
|
82 |
+
current_step = state.global_step
|
83 |
+
|
84 |
+
total_epochs = args.num_train_epochs
|
85 |
+
current_epoch = state.epoch
|
86 |
+
|
87 |
+
log_history = state.log_history
|
88 |
+
|
89 |
+
update_training_states(
|
90 |
+
total_steps=total_steps,
|
91 |
+
current_step=current_step,
|
92 |
+
total_epochs=total_epochs,
|
93 |
+
current_epoch=current_epoch,
|
94 |
+
log_history=log_history
|
95 |
+
)
|
96 |
+
except Exception as e:
|
97 |
+
print("Error occurred while updating UI status:", e)
|
98 |
+
traceback.print_exc()
|
99 |
+
|
100 |
+
def on_epoch_begin(self, args, state, control, **kwargs):
|
101 |
+
self._on_progress(args, state, control)
|
102 |
+
|
103 |
+
def on_step_end(self, args, state, control, **kwargs):
|
104 |
+
self._on_progress(args, state, control)
|
llama_lora/utils/eta_predictor.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import traceback
|
3 |
+
from collections import deque
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
|
7 |
+
class ETAPredictor:
|
8 |
+
def __init__(self, lookback_minutes: int = 180):
|
9 |
+
self.lookback_seconds = lookback_minutes * 60 # convert minutes to seconds
|
10 |
+
self.data = deque()
|
11 |
+
|
12 |
+
def _cleanup_old_data(self):
|
13 |
+
current_time = time.time()
|
14 |
+
while self.data and current_time - self.data[0][1] > self.lookback_seconds:
|
15 |
+
self.data.popleft()
|
16 |
+
|
17 |
+
def predict_eta(
|
18 |
+
self, current_step: int, total_steps: int
|
19 |
+
) -> Optional[int]:
|
20 |
+
try:
|
21 |
+
current_time = time.time()
|
22 |
+
|
23 |
+
# Calculate dynamic log interval based on current logged data
|
24 |
+
log_interval = 1
|
25 |
+
if len(self.data) > 100:
|
26 |
+
log_interval = 10
|
27 |
+
|
28 |
+
# Only log data if last log is at least log_interval seconds ago
|
29 |
+
if len(self.data) < 1 or current_time - self.data[-1][1] >= log_interval:
|
30 |
+
self.data.append((current_step, current_time))
|
31 |
+
self._cleanup_old_data()
|
32 |
+
|
33 |
+
# Only predict if we have enough data
|
34 |
+
if len(self.data) < 2 or self.data[-1][1] - self.data[0][1] < 5:
|
35 |
+
return None
|
36 |
+
|
37 |
+
first_step, first_time = self.data[0]
|
38 |
+
steps_completed = current_step - first_step
|
39 |
+
time_elapsed = current_time - first_time
|
40 |
+
|
41 |
+
if steps_completed == 0:
|
42 |
+
return None
|
43 |
+
|
44 |
+
time_per_step = time_elapsed / steps_completed
|
45 |
+
steps_remaining = total_steps - current_step
|
46 |
+
|
47 |
+
remaining_seconds = steps_remaining * time_per_step
|
48 |
+
eta_unix_timestamp = current_time + remaining_seconds
|
49 |
+
|
50 |
+
return int(eta_unix_timestamp)
|
51 |
+
except Exception as e:
|
52 |
+
print("Error predicting ETA:", e)
|
53 |
+
traceback.print_exc()
|
54 |
+
return None
|