zetavg commited on
Commit
79d936d
·
unverified ·
1 Parent(s): 9cd5ad7

make the training process async

Browse files
README.md CHANGED
@@ -70,7 +70,7 @@ setup: |
70
  # Start the app.
71
  run: |
72
  echo 'Starting...'
73
- python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key="$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --base_model=decapoda-research/llama-7b-hf --base_model_choices='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j,databricks/dolly-v2-7b --share
74
  ```
75
 
76
  Then launch a cluster to run the task:
@@ -100,7 +100,7 @@ When you are done, run `sky stop <cluster_name>` to stop the cluster. To termina
100
 
101
  ```bash
102
  pip install -r requirements.lock.txt
103
- python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --share
104
  ```
105
 
106
  You will see the local and public URLs of the app in the terminal. Open the URL in your browser to use the app.
 
70
  # Start the app.
71
  run: |
72
  echo 'Starting...'
73
+ python llama_lora_tuner/app.py --data_dir='/data' --wandb_api_key="$([ -f /data/secrets/wandb_api_key ] && cat /data/secrets/wandb_api_key | tr -d '\n')" --timezone='Atlantic/Reykjavik' --base_model=decapoda-research/llama-7b-hf --base_model_choices='decapoda-research/llama-7b-hf,nomic-ai/gpt4all-j,databricks/dolly-v2-7b --share
74
  ```
75
 
76
  Then launch a cluster to run the task:
 
100
 
101
  ```bash
102
  pip install -r requirements.lock.txt
103
+ python app.py --data_dir='./data' --base_model='decapoda-research/llama-7b-hf' --timezone='Atlantic/Reykjavik' --share
104
  ```
105
 
106
  You will see the local and public URLs of the app in the terminal. Open the URL in your browser to use the app.
app.py CHANGED
@@ -28,6 +28,7 @@ def main(
28
  ui_dev_mode: Union[bool, None] = None,
29
  wandb_api_key: Union[str, None] = None,
30
  wandb_project: Union[str, None] = None,
 
31
  ):
32
  '''
33
  Start the LLaMA-LoRA Tuner UI.
@@ -76,6 +77,9 @@ def main(
76
  if wandb_project is not None:
77
  Config.default_wandb_project = wandb_project
78
 
 
 
 
79
  if ui_dev_mode is not None:
80
  Config.ui_dev_mode = ui_dev_mode
81
 
 
28
  ui_dev_mode: Union[bool, None] = None,
29
  wandb_api_key: Union[str, None] = None,
30
  wandb_project: Union[str, None] = None,
31
+ timezone: Union[str, None] = None,
32
  ):
33
  '''
34
  Start the LLaMA-LoRA Tuner UI.
 
77
  if wandb_project is not None:
78
  Config.default_wandb_project = wandb_project
79
 
80
+ if timezone is not None:
81
+ Config.timezone = timezone
82
+
83
  if ui_dev_mode is not None:
84
  Config.ui_dev_mode = ui_dev_mode
85
 
config.yaml.sample CHANGED
@@ -9,6 +9,8 @@ base_model_choices:
9
  load_8bit: false
10
  trust_remote_code: false
11
 
 
 
12
  # UI Customization
13
  # ui_title: LLM Tuner
14
  # ui_emoji: 🦙🎛️
 
9
  load_8bit: false
10
  trust_remote_code: false
11
 
12
+ # timezone: Atlantic/Reykjavik
13
+
14
  # UI Customization
15
  # ui_title: LLM Tuner
16
  # ui_emoji: 🦙🎛️
llama_lora/config.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
- from typing import List, Union
 
3
 
4
 
5
  class Config:
@@ -15,6 +16,8 @@ class Config:
15
 
16
  trust_remote_code: bool = False
17
 
 
 
18
  # WandB
19
  enable_wandb: Union[bool, None] = False
20
  wandb_api_key: Union[str, None] = None
@@ -37,6 +40,9 @@ def process_config():
37
  base_model_choices = [name.strip() for name in base_model_choices]
38
  Config.base_model_choices = base_model_choices
39
 
 
 
 
40
  if Config.default_base_model_name not in Config.base_model_choices:
41
  Config.base_model_choices = [Config.default_base_model_name] + Config.base_model_choices
42
 
 
1
  import os
2
+ import pytz
3
+ from typing import List, Union, Any
4
 
5
 
6
  class Config:
 
16
 
17
  trust_remote_code: bool = False
18
 
19
+ timezone: Any = pytz.UTC
20
+
21
  # WandB
22
  enable_wandb: Union[bool, None] = False
23
  wandb_api_key: Union[str, None] = None
 
40
  base_model_choices = [name.strip() for name in base_model_choices]
41
  Config.base_model_choices = base_model_choices
42
 
43
+ if isinstance(Config.timezone, str):
44
+ Config.timezone = pytz.timezone(Config.timezone)
45
+
46
  if Config.default_base_model_name not in Config.base_model_choices:
47
  Config.base_model_choices = [Config.default_base_model_name] + Config.base_model_choices
48
 
llama_lora/globals.py CHANGED
@@ -12,6 +12,7 @@ import nvidia_smi
12
  from .dynamic_import import dynamic_import
13
  from .config import Config
14
  from .utils.lru_cache import LRUCache
 
15
 
16
 
17
  class Global:
@@ -31,6 +32,24 @@ class Global:
31
  # Training Control
32
  should_stop_training: bool = False
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # Generation Control
35
  should_stop_generating: bool = False
36
  generation_force_stopped_at: Union[float, None] = None
 
12
  from .dynamic_import import dynamic_import
13
  from .config import Config
14
  from .utils.lru_cache import LRUCache
15
+ from .utils.eta_predictor import ETAPredictor
16
 
17
 
18
  class Global:
 
32
  # Training Control
33
  should_stop_training: bool = False
34
 
35
+ # Training Status
36
+ is_train_starting: bool = False
37
+ is_training: bool = False
38
+ train_started_at: float = 0.0
39
+ training_error_message: Union[str, None] = None
40
+ training_error_detail: Union[str, None] = None
41
+ training_total_epochs: int = 0
42
+ training_current_epoch: float = 0.0
43
+ training_total_steps: int = 0
44
+ training_current_step: int = 0
45
+ training_progress: float = 0.0
46
+ training_log_history: List[Any] = []
47
+ training_status_text: str = ""
48
+ training_eta_predictor = ETAPredictor()
49
+ training_eta: Union[int, None] = None
50
+ train_output: Union[None, Any] = None
51
+ train_output_str: Union[None, str] = None
52
+
53
  # Generation Control
54
  should_stop_generating: bool = False
55
  generation_force_stopped_at: Union[float, None] = None
llama_lora/models.py CHANGED
@@ -26,6 +26,8 @@ def get_peft_model_class():
26
  def get_new_base_model(base_model_name):
27
  if Config.ui_dev_mode:
28
  return
 
 
29
 
30
  if Global.new_base_model_that_is_ready_to_be_used:
31
  if Global.name_of_new_base_model_that_is_ready_to_be_used == base_model_name:
@@ -121,6 +123,9 @@ def get_tokenizer(base_model_name):
121
  if Config.ui_dev_mode:
122
  return
123
 
 
 
 
124
  loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
125
  if loaded_tokenizer:
126
  return loaded_tokenizer
@@ -150,6 +155,9 @@ def get_model(
150
  if Config.ui_dev_mode:
151
  return
152
 
 
 
 
153
  if peft_model_name == "None":
154
  peft_model_name = None
155
 
 
26
  def get_new_base_model(base_model_name):
27
  if Config.ui_dev_mode:
28
  return
29
+ if Global.is_train_starting or Global.is_training:
30
+ raise Exception("Cannot load new base model while training.")
31
 
32
  if Global.new_base_model_that_is_ready_to_be_used:
33
  if Global.name_of_new_base_model_that_is_ready_to_be_used == base_model_name:
 
123
  if Config.ui_dev_mode:
124
  return
125
 
126
+ if Global.is_train_starting or Global.is_training:
127
+ raise Exception("Cannot load new base model while training.")
128
+
129
  loaded_tokenizer = Global.loaded_tokenizers.get(base_model_name)
130
  if loaded_tokenizer:
131
  return loaded_tokenizer
 
155
  if Config.ui_dev_mode:
156
  return
157
 
158
+ if Global.is_train_starting or Global.is_training:
159
+ raise Exception("Cannot load new base model while training.")
160
+
161
  if peft_model_name == "None":
162
  peft_model_name = None
163
 
llama_lora/ui/finetune/finetune_ui.py CHANGED
@@ -27,7 +27,8 @@ from .previewing import (
27
  refresh_dataset_items_count,
28
  )
29
  from .training import (
30
- do_train
 
31
  )
32
 
33
  register_css_style('finetune', relative_read_file(__file__, "style.css"))
@@ -770,19 +771,22 @@ def finetune_ui():
770
  )
771
  )
772
 
773
- train_output = gr.Text(
774
  "Training results will be shown here.",
775
  label="Train Output",
776
  elem_id="finetune_training_status")
777
 
778
- train_progress = train_btn.click(
 
 
 
779
  fn=do_train,
780
  inputs=(dataset_inputs + finetune_args + [
781
  model_name,
782
  continue_from_model,
783
  continue_from_checkpoint,
784
  ]),
785
- outputs=train_output
786
  )
787
 
788
  # controlled by JS, shows the confirm_abort_button
@@ -790,13 +794,20 @@ def finetune_ui():
790
  confirm_abort_button.click(
791
  fn=do_abort_training,
792
  inputs=None, outputs=None,
793
- cancels=[train_progress])
794
-
795
- stop_timeoutable_btn = gr.Button(
796
- "stop not-responding elements",
797
- elem_id="inference_stop_timeoutable_btn",
798
- elem_classes="foot_stop_timeoutable_btn")
799
- stop_timeoutable_btn.click(
800
- fn=None, inputs=None, outputs=None, cancels=things_that_might_timeout)
801
 
 
 
 
 
 
 
802
  finetune_ui_blocks.load(_js=relative_read_file(__file__, "script.js"))
 
 
 
 
 
 
 
 
 
27
  refresh_dataset_items_count,
28
  )
29
  from .training import (
30
+ do_train,
31
+ render_training_status
32
  )
33
 
34
  register_css_style('finetune', relative_read_file(__file__, "style.css"))
 
771
  )
772
  )
773
 
774
+ train_status = gr.HTML(
775
  "Training results will be shown here.",
776
  label="Train Output",
777
  elem_id="finetune_training_status")
778
 
779
+ training_indicator = gr.HTML(
780
+ "training_indicator", visible=False, elem_id="finetune_training_indicator")
781
+
782
+ train_start = train_btn.click(
783
  fn=do_train,
784
  inputs=(dataset_inputs + finetune_args + [
785
  model_name,
786
  continue_from_model,
787
  continue_from_checkpoint,
788
  ]),
789
+ outputs=[train_status, training_indicator]
790
  )
791
 
792
  # controlled by JS, shows the confirm_abort_button
 
794
  confirm_abort_button.click(
795
  fn=do_abort_training,
796
  inputs=None, outputs=None,
797
+ cancels=[train_start])
 
 
 
 
 
 
 
798
 
799
+ training_status_updates = finetune_ui_blocks.load(
800
+ fn=render_training_status,
801
+ inputs=None,
802
+ outputs=[train_status, training_indicator],
803
+ every=0.1
804
+ )
805
  finetune_ui_blocks.load(_js=relative_read_file(__file__, "script.js"))
806
+
807
+ # things_that_might_timeout.append(training_status_updates)
808
+ stop_timeoutable_btn = gr.Button(
809
+ "stop not-responding elements",
810
+ elem_id="inference_stop_timeoutable_btn",
811
+ elem_classes="foot_stop_timeoutable_btn")
812
+ stop_timeoutable_btn.click(
813
+ fn=None, inputs=None, outputs=None, cancels=things_that_might_timeout)
llama_lora/ui/finetune/script.js CHANGED
@@ -130,10 +130,10 @@ function finetune_ui_blocks_js() {
130
 
131
  // Show/hide start and stop button base on the state.
132
  setTimeout(function () {
133
- // Make the '#finetune_training_status > .wrap' element appear
134
- if (!document.querySelector('#finetune_training_status > .wrap')) {
135
- document.getElementById('finetune_confirm_stop_btn').click();
136
- }
137
 
138
  setTimeout(function () {
139
  let resetStopButtonTimer;
@@ -156,11 +156,20 @@ function finetune_ui_blocks_js() {
156
  document.getElementById('finetune_confirm_stop_btn').style.display =
157
  'block';
158
  });
159
- const output_wrap_element = document.querySelector(
160
- '#finetune_training_status > .wrap'
 
 
 
161
  );
162
- function handle_output_wrap_element_class_change() {
163
- if (Array.from(output_wrap_element.classList).includes('hide')) {
 
 
 
 
 
 
164
  if (resetStopButtonTimer) clearTimeout(resetStopButtonTimer);
165
  document.getElementById('finetune_start_btn').style.display = 'block';
166
  document.getElementById('finetune_stop_btn').style.display = 'none';
@@ -173,13 +182,19 @@ function finetune_ui_blocks_js() {
173
  'none';
174
  }
175
  }
 
 
 
 
 
 
176
  new MutationObserver(function (mutationsList, observer) {
177
- handle_output_wrap_element_class_change();
178
- }).observe(output_wrap_element, {
179
  attributes: true,
180
  attributeFilter: ['class'],
181
  });
182
- handle_output_wrap_element_class_change();
183
  }, 500);
184
  }, 0);
185
  }
 
130
 
131
  // Show/hide start and stop button base on the state.
132
  setTimeout(function () {
133
+ // Make the '#finetune_training_indicator > .wrap' element appear
134
+ // if (!document.querySelector('#finetune_training_indicator > .wrap')) {
135
+ // document.getElementById('finetune_confirm_stop_btn').click();
136
+ // }
137
 
138
  setTimeout(function () {
139
  let resetStopButtonTimer;
 
156
  document.getElementById('finetune_confirm_stop_btn').style.display =
157
  'block';
158
  });
159
+ // const training_indicator_wrap_element = document.querySelector(
160
+ // '#finetune_training_indicator > .wrap'
161
+ // );
162
+ const training_indicator_element = document.querySelector(
163
+ '#finetune_training_indicator'
164
  );
165
+ let isTraining = undefined;
166
+ function handle_training_indicator_change() {
167
+ // const wrapperHidden = Array.from(training_indicator_wrap_element.classList).includes('hide');
168
+ const hidden = Array.from(training_indicator_element.classList).includes('hidden');
169
+ const newIsTraining = !(/* wrapperHidden && */ hidden);
170
+ if (newIsTraining === isTraining) return;
171
+ isTraining = newIsTraining;
172
+ if (!isTraining) {
173
  if (resetStopButtonTimer) clearTimeout(resetStopButtonTimer);
174
  document.getElementById('finetune_start_btn').style.display = 'block';
175
  document.getElementById('finetune_stop_btn').style.display = 'none';
 
182
  'none';
183
  }
184
  }
185
+ // new MutationObserver(function (mutationsList, observer) {
186
+ // handle_training_indicator_change();
187
+ // }).observe(training_indicator_wrap_element, {
188
+ // attributes: true,
189
+ // attributeFilter: ['class'],
190
+ // });
191
  new MutationObserver(function (mutationsList, observer) {
192
+ handle_training_indicator_change();
193
+ }).observe(training_indicator_element, {
194
  attributes: true,
195
  attributeFilter: ['class'],
196
  });
197
+ handle_training_indicator_change();
198
  }, 500);
199
  }, 0);
200
  }
llama_lora/ui/finetune/style.css CHANGED
@@ -255,8 +255,118 @@
255
  display: none;
256
  }
257
 
258
- /* in case if there's too many logs on the previous run and made the box too high */
259
- #finetune_training_status:has(.wrap:not(.hide)) {
260
- max-height: 160px;
261
- height: 160px;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  display: none;
256
  }
257
 
258
+ #finetune_training_status > .wrap {
259
+ border: 0;
260
+ background: transparent;
261
+ pointer-events: none;
262
+ top: 0;
263
+ bottom: 0;
264
+ left: 0;
265
+ right: 0;
266
+ }
267
+ #finetune_training_status > .wrap .meta-text-center {
268
+ transform: none !important;
269
+ }
270
+
271
+ #finetune_training_status .progress-block {
272
+ min-height: 100px;
273
+ display: flex;
274
+ justify-content: center;
275
+ align-items: center;
276
+ background: var(--panel-background-fill);
277
+ border-radius: var(--radius-lg);
278
+ border: var(--block-border-width) solid var(--border-color-primary);
279
+ padding: var(--block-padding);
280
+ }
281
+ #finetune_training_status .progress-block.is_training {
282
+ min-height: 160px;
283
+ }
284
+ #finetune_training_status .progress-block .empty-text {
285
+ text-transform: uppercase;
286
+ font-weight: 700;
287
+ font-size: 120%;
288
+ opacity: 0.12;
289
+ }
290
+ #finetune_training_status .progress-block .meta-text {
291
+ position: absolute;
292
+ top: 0;
293
+ right: 0;
294
+ z-index: var(--layer-2);
295
+ padding: var(--size-1) var(--size-2);
296
+ font-size: var(--text-sm);
297
+ font-family: var(--font-mono);
298
+ }
299
+ #finetune_training_status .progress-block .status {
300
+ white-space: pre-wrap;
301
+ }
302
+ #finetune_training_status .progress-block .progress-level {
303
+ display: flex;
304
+ flex-direction: column;
305
+ align-items: center;
306
+ z-index: var(--layer-2);
307
+ width: var(--size-full);
308
+ }
309
+ #finetune_training_status .progress-block .progress-level-inner {
310
+ margin: var(--size-2) auto;
311
+ color: var(--body-text-color);
312
+ font-size: var(--text-sm);
313
+ font-family: var(--font-mono);
314
+ }
315
+ #finetune_training_status .progress-block .progress-bar-wrap {
316
+ border: 1px solid var(--border-color-primary);
317
+ background: var(--background-fill-primary);
318
+ width: 55.5%;
319
+ height: var(--size-4);
320
+ }
321
+ #finetune_training_status .progress-block .progress-bar {
322
+ transform-origin: left;
323
+ background-color: var(--loader-color);
324
+ width: var(--size-full);
325
+ height: var(--size-full);
326
+ transition: all 150ms ease 0s;
327
  }
328
+
329
+ #finetune_training_status .progress-block .output {
330
+ display: flex;
331
+ flex-direction: column;
332
+ justify-content: center;
333
+ align-items: center;
334
+ }
335
+ #finetune_training_status .progress-block .output .title {
336
+ padding: var(--size-1) var(--size-3);
337
+ font-weight: var(--weight-bold);
338
+ font-size: var(--text-lg);
339
+ line-height: var(--line-xs);
340
+ }
341
+ #finetune_training_status .progress-block .output .message {
342
+ padding: var(--size-1) var(--size-3);
343
+ color: var(--body-text-color) !important;
344
+ font-family: var(--font-mono);
345
+ white-space: pre-wrap;
346
+ }
347
+
348
+ #finetune_training_status .progress-block .error {
349
+ display: flex;
350
+ flex-direction: column;
351
+ justify-content: center;
352
+ align-items: center;
353
+ }
354
+ #finetune_training_status .progress-block .error .title {
355
+ padding: var(--size-1) var(--size-3);
356
+ color: var(--color-red-500);
357
+ font-weight: var(--weight-bold);
358
+ font-size: var(--text-lg);
359
+ line-height: var(--line-xs);
360
+ }
361
+ #finetune_training_status .progress-block .error .error-message {
362
+ padding: var(--size-1) var(--size-3);
363
+ color: var(--body-text-color) !important;
364
+ font-family: var(--font-mono);
365
+ white-space: pre-wrap;
366
+ }
367
+ #finetune_training_status .progress-block.is_error {
368
+ /* background: var(--error-background-fill) !important; */
369
+ border: 1px solid var(--error-border-color) !important;
370
+ }
371
+
372
+ #finetune_training_indicator { display: none; }
llama_lora/ui/finetune/training.py CHANGED
@@ -1,24 +1,26 @@
1
  import os
2
  import json
3
  import time
 
 
 
 
 
4
  import gradio as gr
5
- import math
6
 
7
- from transformers import TrainerCallback
8
  from huggingface_hub import try_to_load_from_cache, snapshot_download
9
 
10
  from ...config import Config
11
  from ...globals import Global
12
  from ...models import clear_cache, unload_models
13
  from ...utils.prompter import Prompter
 
 
 
 
14
 
15
  from .data_processing import get_data_from_input
16
 
17
- should_training_progress_track_tqdm = True
18
-
19
- if Global.gpu_total_cores is not None and Global.gpu_total_cores > 2560:
20
- should_training_progress_track_tqdm = False
21
-
22
 
23
  def do_train(
24
  # Dataset
@@ -55,8 +57,14 @@ def do_train(
55
  model_name,
56
  continue_from_model,
57
  continue_from_checkpoint,
58
- progress=gr.Progress(track_tqdm=should_training_progress_track_tqdm),
59
  ):
 
 
 
 
 
 
60
  try:
61
  base_model_name = Global.base_model_name
62
  tokenizer_name = Global.tokenizer_name or Global.base_model_name
@@ -115,18 +123,47 @@ def do_train(
115
  raise ValueError(
116
  f"The output directory already exists and is not empty. ({output_dir})")
117
 
118
- if not should_training_progress_track_tqdm:
119
- progress(0, desc="Preparing train data...")
 
 
 
120
 
121
- # Need RAM for training
122
- unload_models()
123
- Global.new_base_model_that_is_ready_to_be_used = None
124
- Global.name_of_new_base_model_that_is_ready_to_be_used = None
125
- clear_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  prompter = Prompter(template)
128
- # variable_names = prompter.get_variable_names()
129
-
130
  data = get_data_from_input(
131
  load_dataset_from=load_dataset_from,
132
  dataset_text=dataset_text,
@@ -138,208 +175,234 @@ def do_train(
138
  prompter=prompter
139
  )
140
 
141
- train_data = prompter.get_train_data_from_dataset(data)
142
-
143
- def get_progress_text(epoch, epochs, last_loss):
144
- progress_detail = f"Epoch {math.ceil(epoch)}/{epochs}"
145
- if last_loss is not None:
146
- progress_detail += f", Loss: {last_loss:.4f}"
147
- return f"Training... ({progress_detail})"
148
-
149
- if Config.ui_dev_mode:
150
- Global.should_stop_training = False
151
-
152
- message = f"""Currently in UI dev mode, not doing the actual training.
153
-
154
- Train options: {json.dumps({
155
- 'max_seq_length': max_seq_length,
156
- 'val_set_size': evaluate_data_count,
157
- 'micro_batch_size': micro_batch_size,
158
- 'gradient_accumulation_steps': gradient_accumulation_steps,
159
- 'epochs': epochs,
160
- 'learning_rate': learning_rate,
161
- 'train_on_inputs': train_on_inputs,
162
- 'lora_r': lora_r,
163
- 'lora_alpha': lora_alpha,
164
- 'lora_dropout': lora_dropout,
165
- 'lora_target_modules': lora_target_modules,
166
- 'lora_modules_to_save': lora_modules_to_save,
167
- 'load_in_8bit': load_in_8bit,
168
- 'fp16': fp16,
169
- 'bf16': bf16,
170
- 'gradient_checkpointing': gradient_checkpointing,
171
- 'model_name': model_name,
172
- 'continue_from_model': continue_from_model,
173
- 'continue_from_checkpoint': continue_from_checkpoint,
174
- 'resume_from_checkpoint_param': resume_from_checkpoint_param,
175
- }, indent=2)}
176
-
177
- Train data (first 10):
178
- {json.dumps(train_data[:10], indent=2)}
179
- """
180
- print(message)
181
-
182
- for i in range(300):
183
- if (Global.should_stop_training):
 
 
 
184
  return
185
- epochs = 3
186
- epoch = i / 100
187
- last_loss = None
188
- if (i > 20):
189
- last_loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
190
-
191
- progress(
192
- (i, 300),
193
- desc="(Simulate) " +
194
- get_progress_text(epoch, epochs, last_loss)
195
- )
196
-
197
- time.sleep(0.1)
198
-
199
- time.sleep(2)
200
- return message
201
-
202
- if not should_training_progress_track_tqdm:
203
- progress(
204
- 0, desc=f"Preparing model {base_model_name} for training...")
205
-
206
- log_history = []
207
-
208
- class UiTrainerCallback(TrainerCallback):
209
- def _on_progress(self, args, state, control):
210
- nonlocal log_history
211
 
212
- if Global.should_stop_training:
213
- control.should_training_stop = True
214
- total_steps = (
215
- state.max_steps if state.max_steps is not None else state.num_train_epochs * state.steps_per_epoch)
216
- log_history = state.log_history
217
- last_history = None
218
- last_loss = None
219
- if len(log_history) > 0:
220
- last_history = log_history[-1]
221
- last_loss = last_history.get('loss', None)
222
-
223
- progress_detail = f"Epoch {math.ceil(state.epoch)}/{epochs}"
224
- if last_loss is not None:
225
- progress_detail += f", Loss: {last_loss:.4f}"
226
-
227
- progress(
228
- (state.global_step, total_steps),
229
- desc=f"Training... ({progress_detail})"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  )
231
 
232
- def on_epoch_begin(self, args, state, control, **kwargs):
233
- self._on_progress(args, state, control)
234
-
235
- def on_step_end(self, args, state, control, **kwargs):
236
- self._on_progress(args, state, control)
237
-
238
- training_callbacks = [UiTrainerCallback]
239
-
240
- Global.should_stop_training = False
241
-
242
- # Do not let other tqdm iterations interfere the progress reporting after training starts.
243
- # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
244
 
245
- if not os.path.exists(output_dir):
246
- os.makedirs(output_dir)
247
 
248
- with open(os.path.join(output_dir, "info.json"), 'w') as info_json_file:
249
- dataset_name = "N/A (from text input)"
250
- if load_dataset_from == "Data Dir":
251
- dataset_name = dataset_from_data_dir
252
 
253
- info = {
254
- 'base_model': base_model_name,
255
- 'prompt_template': template,
256
- 'dataset_name': dataset_name,
257
- 'dataset_rows': len(train_data),
258
- 'timestamp': time.time(),
259
 
260
- # These will be saved in another JSON file by the train function
261
- # 'max_seq_length': max_seq_length,
262
- # 'train_on_inputs': train_on_inputs,
263
-
264
- # 'micro_batch_size': micro_batch_size,
265
- # 'gradient_accumulation_steps': gradient_accumulation_steps,
266
- # 'epochs': epochs,
267
- # 'learning_rate': learning_rate,
268
-
269
- # 'evaluate_data_count': evaluate_data_count,
270
-
271
- # 'lora_r': lora_r,
272
- # 'lora_alpha': lora_alpha,
273
- # 'lora_dropout': lora_dropout,
274
- # 'lora_target_modules': lora_target_modules,
275
- }
276
- if continue_from_model:
277
- info['continued_from_model'] = continue_from_model
278
- if continue_from_checkpoint:
279
- info['continued_from_checkpoint'] = continue_from_checkpoint
280
-
281
- if Global.version:
282
- info['tuner_version'] = Global.version
283
-
284
- json.dump(info, info_json_file, indent=2)
285
-
286
- if not should_training_progress_track_tqdm:
287
- progress(0, desc="Train starting...")
288
-
289
- wandb_group = template
290
- wandb_tags = [f"template:{template}"]
291
- if load_dataset_from == "Data Dir" and dataset_from_data_dir:
292
- wandb_group += f"/{dataset_from_data_dir}"
293
- wandb_tags.append(f"dataset:{dataset_from_data_dir}")
294
-
295
- train_output = Global.finetune_train_fn(
296
- base_model=base_model_name,
297
- tokenizer=tokenizer_name,
298
- output_dir=output_dir,
299
- train_data=train_data,
300
- # 128, # batch_size (is not used, use gradient_accumulation_steps instead)
301
- micro_batch_size=micro_batch_size,
302
- gradient_accumulation_steps=gradient_accumulation_steps,
303
- num_train_epochs=epochs,
304
- learning_rate=learning_rate,
305
- cutoff_len=max_seq_length,
306
- val_set_size=evaluate_data_count,
307
- lora_r=lora_r,
308
- lora_alpha=lora_alpha,
309
- lora_dropout=lora_dropout,
310
- lora_target_modules=lora_target_modules,
311
- lora_modules_to_save=lora_modules_to_save,
312
- train_on_inputs=train_on_inputs,
313
- load_in_8bit=load_in_8bit,
314
- fp16=fp16,
315
- bf16=bf16,
316
- gradient_checkpointing=gradient_checkpointing,
317
- group_by_length=False,
318
- resume_from_checkpoint=resume_from_checkpoint_param,
319
- save_steps=save_steps,
320
- save_total_limit=save_total_limit,
321
- logging_steps=logging_steps,
322
- additional_training_arguments=additional_training_arguments,
323
- additional_lora_config=additional_lora_config,
324
- callbacks=training_callbacks,
325
- wandb_api_key=Config.wandb_api_key,
326
- wandb_project=Config.default_wandb_project if Config.enable_wandb else None,
327
- wandb_group=wandb_group,
328
- wandb_run_name=model_name,
329
- wandb_tags=wandb_tags
330
- )
331
-
332
- logs_str = "\n".join([json.dumps(log)
333
- for log in log_history]) or "None"
334
-
335
- result_message = f"Training ended:\n{str(train_output)}"
336
- print(result_message)
337
- # result_message += f"\n\nLogs:\n{logs_str}"
338
-
339
- clear_cache()
340
-
341
- return result_message
342
 
343
  except Exception as e:
344
- raise gr.Error(
345
- f"{e} (To dismiss this error, click the 'Abort' button)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import json
3
  import time
4
+ import datetime
5
+ import pytz
6
+ import socket
7
+ import threading
8
+ import traceback
9
  import gradio as gr
 
10
 
 
11
  from huggingface_hub import try_to_load_from_cache, snapshot_download
12
 
13
  from ...config import Config
14
  from ...globals import Global
15
  from ...models import clear_cache, unload_models
16
  from ...utils.prompter import Prompter
17
+ from ..trainer_callback import (
18
+ UiTrainerCallback, reset_training_status,
19
+ update_training_states, set_train_output
20
+ )
21
 
22
  from .data_processing import get_data_from_input
23
 
 
 
 
 
 
24
 
25
  def do_train(
26
  # Dataset
 
57
  model_name,
58
  continue_from_model,
59
  continue_from_checkpoint,
60
+ progress=gr.Progress(track_tqdm=False),
61
  ):
62
+ if Global.is_training:
63
+ return render_training_status()
64
+
65
+ reset_training_status()
66
+ Global.is_train_starting = True
67
+
68
  try:
69
  base_model_name = Global.base_model_name
70
  tokenizer_name = Global.tokenizer_name or Global.base_model_name
 
123
  raise ValueError(
124
  f"The output directory already exists and is not empty. ({output_dir})")
125
 
126
+ wandb_group = template
127
+ wandb_tags = [f"template:{template}"]
128
+ if load_dataset_from == "Data Dir" and dataset_from_data_dir:
129
+ wandb_group += f"/{dataset_from_data_dir}"
130
+ wandb_tags.append(f"dataset:{dataset_from_data_dir}")
131
 
132
+ finetune_args = {
133
+ 'base_model': base_model_name,
134
+ 'tokenizer': tokenizer_name,
135
+ 'output_dir': output_dir,
136
+ 'micro_batch_size': micro_batch_size,
137
+ 'gradient_accumulation_steps': gradient_accumulation_steps,
138
+ 'num_train_epochs': epochs,
139
+ 'learning_rate': learning_rate,
140
+ 'cutoff_len': max_seq_length,
141
+ 'val_set_size': evaluate_data_count,
142
+ 'lora_r': lora_r,
143
+ 'lora_alpha': lora_alpha,
144
+ 'lora_dropout': lora_dropout,
145
+ 'lora_target_modules': lora_target_modules,
146
+ 'lora_modules_to_save': lora_modules_to_save,
147
+ 'train_on_inputs': train_on_inputs,
148
+ 'load_in_8bit': load_in_8bit,
149
+ 'fp16': fp16,
150
+ 'bf16': bf16,
151
+ 'gradient_checkpointing': gradient_checkpointing,
152
+ 'group_by_length': False,
153
+ 'resume_from_checkpoint': resume_from_checkpoint_param,
154
+ 'save_steps': save_steps,
155
+ 'save_total_limit': save_total_limit,
156
+ 'logging_steps': logging_steps,
157
+ 'additional_training_arguments': additional_training_arguments,
158
+ 'additional_lora_config': additional_lora_config,
159
+ 'wandb_api_key': Config.wandb_api_key,
160
+ 'wandb_project': Config.default_wandb_project if Config.enable_wandb else None,
161
+ 'wandb_group': wandb_group,
162
+ 'wandb_run_name': model_name,
163
+ 'wandb_tags': wandb_tags
164
+ }
165
 
166
  prompter = Prompter(template)
 
 
167
  data = get_data_from_input(
168
  load_dataset_from=load_dataset_from,
169
  dataset_text=dataset_text,
 
175
  prompter=prompter
176
  )
177
 
178
+ def training():
179
+ Global.is_training = True
180
+
181
+ try:
182
+ # Need RAM for training
183
+ unload_models()
184
+ Global.new_base_model_that_is_ready_to_be_used = None
185
+ Global.name_of_new_base_model_that_is_ready_to_be_used = None
186
+ clear_cache()
187
+
188
+ train_data = prompter.get_train_data_from_dataset(data)
189
+
190
+ if Config.ui_dev_mode:
191
+ message = "Currently in UI dev mode, not doing the actual training."
192
+ message += f"\n\nArgs: {json.dumps(finetune_args, indent=2)}"
193
+ message += f"\n\nTrain data (first 5):\n{json.dumps(train_data[:5], indent=2)}"
194
+
195
+ print(message)
196
+
197
+ total_steps = 300
198
+ for i in range(300):
199
+ if (Global.should_stop_training):
200
+ break
201
+
202
+ current_step = i + 1
203
+ total_epochs = 3
204
+ current_epoch = i / 100
205
+ log_history = []
206
+
207
+ if (i > 20):
208
+ loss = 3 + (i - 0) * (0.5 - 3) / (300 - 0)
209
+ log_history = [{'loss': loss}]
210
+
211
+ update_training_states(
212
+ total_steps=total_steps,
213
+ current_step=current_step,
214
+ total_epochs=total_epochs,
215
+ current_epoch=current_epoch,
216
+ log_history=log_history
217
+ )
218
+ time.sleep(0.1)
219
+
220
+ result_message = set_train_output(message)
221
+ print(result_message)
222
+ time.sleep(1)
223
+ Global.is_training = False
224
  return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
+ training_callbacks = [UiTrainerCallback]
227
+
228
+ if not os.path.exists(output_dir):
229
+ os.makedirs(output_dir)
230
+
231
+ with open(os.path.join(output_dir, "info.json"), 'w') as info_json_file:
232
+ dataset_name = "N/A (from text input)"
233
+ if load_dataset_from == "Data Dir":
234
+ dataset_name = dataset_from_data_dir
235
+
236
+ info = {
237
+ 'base_model': base_model_name,
238
+ 'prompt_template': template,
239
+ 'dataset_name': dataset_name,
240
+ 'dataset_rows': len(train_data),
241
+ 'trained_on_machine': socket.gethostname(),
242
+ 'timestamp': time.time(),
243
+ }
244
+ if continue_from_model:
245
+ info['continued_from_model'] = continue_from_model
246
+ if continue_from_checkpoint:
247
+ info['continued_from_checkpoint'] = continue_from_checkpoint
248
+
249
+ if Global.version:
250
+ info['tuner_version'] = Global.version
251
+
252
+ json.dump(info, info_json_file, indent=2)
253
+
254
+ train_output = Global.finetune_train_fn(
255
+ train_data=train_data,
256
+ callbacks=training_callbacks,
257
+ **finetune_args,
258
  )
259
 
260
+ result_message = set_train_output(train_output)
261
+ print(result_message + "\n" + str(train_output))
 
 
 
 
 
 
 
 
 
 
262
 
263
+ clear_cache()
 
264
 
265
+ Global.is_training = False
 
 
 
266
 
267
+ except Exception as e:
268
+ traceback.print_exc()
269
+ Global.training_error_message = str(e)
270
+ finally:
271
+ Global.is_training = False
 
272
 
273
+ training_thread = threading.Thread(target=training)
274
+ training_thread.daemon = True
275
+ training_thread.start()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
  except Exception as e:
278
+ Global.is_training = False
279
+ traceback.print_exc()
280
+ Global.training_error_message = str(e)
281
+ finally:
282
+ Global.is_train_starting = False
283
+
284
+ return render_training_status()
285
+
286
+
287
+ def render_training_status():
288
+ if not Global.is_training:
289
+ if Global.is_train_starting:
290
+ html_content = """
291
+ <div class="progress-block">
292
+ <div class="progress-level">
293
+ <div class="progress-level-inner">
294
+ Starting...
295
+ </div>
296
+ </div>
297
+ </div>
298
+ """
299
+ return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
300
+
301
+ if Global.training_error_message:
302
+ html_content = f"""
303
+ <div class="progress-block is_error">
304
+ <div class="progress-level">
305
+ <div class="error">
306
+ <div class="title">
307
+ ⚠ Something went wrong
308
+ </div>
309
+ <div class="error-message">{Global.training_error_message}</div>
310
+ </div>
311
+ </div>
312
+ </div>
313
+ """
314
+ return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
315
+
316
+ if Global.train_output_str:
317
+ end_message = "✅ Training completed"
318
+ if Global.should_stop_training:
319
+ end_message = "🛑 Train aborted"
320
+ html_content = f"""
321
+ <div class="progress-block">
322
+ <div class="progress-level">
323
+ <div class="output">
324
+ <div class="title">
325
+ {end_message}
326
+ </div>
327
+ <div class="message">{Global.train_output_str}</div>
328
+ </div>
329
+ </div>
330
+ </div>
331
+ """
332
+ return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
333
+
334
+ if Global.training_status_text:
335
+ html_content = f"""
336
+ <div class="progress-block">
337
+ <div class="status">{Global.training_status_text}</div>
338
+ </div>
339
+ """
340
+ return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
341
+
342
+ html_content = """
343
+ <div class="progress-block">
344
+ <div class="empty-text">
345
+ Training status will be shown here
346
+ </div>
347
+ </div>
348
+ """
349
+ return (gr.HTML.update(value=html_content), gr.HTML.update(visible=False))
350
+
351
+ meta_info = []
352
+ meta_info.append(
353
+ f"{Global.training_current_step}/{Global.training_total_steps} steps")
354
+ current_time = time.time()
355
+ time_elapsed = current_time - Global.train_started_at
356
+ time_remaining = -1
357
+ if Global.training_eta:
358
+ time_remaining = Global.training_eta - current_time
359
+ if time_remaining >= 0:
360
+ meta_info.append(
361
+ f"{format_time(time_elapsed)}<{format_time(time_remaining)}")
362
+ meta_info.append(f"ETA: {format_timestamp(Global.training_eta)}")
363
+ else:
364
+ meta_info.append(format_time(time_elapsed))
365
+
366
+ html_content = f"""
367
+ <div class="progress-block is_training">
368
+ <div class="meta-text">{' | '.join(meta_info)}</div>
369
+ <div class="progress-level">
370
+ <div class="progress-level-inner">
371
+ {Global.training_status_text} - {Global.training_progress * 100:.2f}%
372
+ </div>
373
+ <div class="progress-bar-wrap">
374
+ <div class="progress-bar" style="width: {Global.training_progress * 100:.2f}%;">
375
+ </div>
376
+ </div>
377
+ </div>
378
+ </div>
379
+ """
380
+ return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
381
+
382
+
383
+ def format_time(seconds):
384
+ hours, remainder = divmod(seconds, 3600)
385
+ minutes, seconds = divmod(remainder, 60)
386
+ if hours == 0:
387
+ return "{:02d}:{:02d}".format(int(minutes), int(seconds))
388
+ else:
389
+ return "{:02d}:{:02d}:{:02d}".format(int(hours), int(minutes), int(seconds))
390
+
391
+
392
+ def format_timestamp(timestamp):
393
+ dt_naive = datetime.datetime.utcfromtimestamp(timestamp)
394
+ utc = pytz.UTC
395
+ timezone = Config.timezone
396
+ dt_aware = utc.localize(dt_naive).astimezone(timezone)
397
+ now = datetime.datetime.now(timezone)
398
+ delta = dt_aware.date() - now.date()
399
+ if delta.days == 0:
400
+ time_str = ""
401
+ elif delta.days == 1:
402
+ time_str = "tomorrow at "
403
+ elif delta.days == -1:
404
+ time_str = "yesterday at "
405
+ else:
406
+ time_str = dt_aware.strftime('%A, %B %d at ')
407
+ time_str += dt_aware.strftime('%I:%M %p').lower()
408
+ return time_str
llama_lora/ui/inference_ui.py CHANGED
@@ -381,7 +381,7 @@ def inference_ui():
381
  things_that_might_timeout = []
382
 
383
  with gr.Blocks() as inference_ui_blocks:
384
- with gr.Row():
385
  with gr.Column(elem_id="inference_lora_model_group"):
386
  model_prompt_template_message = gr.Markdown(
387
  "", visible=False, elem_id="inference_lora_model_prompt_template_message")
@@ -402,7 +402,7 @@ def inference_ui():
402
  reload_selections_button.style(
403
  full_width=False,
404
  size="sm")
405
- with gr.Row():
406
  with gr.Column():
407
  with gr.Column(elem_id="inference_prompt_box"):
408
  variable_0 = gr.Textbox(
 
381
  things_that_might_timeout = []
382
 
383
  with gr.Blocks() as inference_ui_blocks:
384
+ with gr.Row(elem_classes="disable_while_training"):
385
  with gr.Column(elem_id="inference_lora_model_group"):
386
  model_prompt_template_message = gr.Markdown(
387
  "", visible=False, elem_id="inference_lora_model_prompt_template_message")
 
402
  reload_selections_button.style(
403
  full_width=False,
404
  size="sm")
405
+ with gr.Row(elem_classes="disable_while_training"):
406
  with gr.Column():
407
  with gr.Column(elem_id="inference_prompt_box"):
408
  variable_0 = gr.Textbox(
llama_lora/ui/main_page.py CHANGED
@@ -18,6 +18,8 @@ def main_page():
18
  title=title,
19
  css=get_css_styles(),
20
  ) as main_page_blocks:
 
 
21
  with gr.Column(elem_id="main_page_content"):
22
  with gr.Row():
23
  gr.Markdown(
@@ -27,7 +29,10 @@ def main_page():
27
  """,
28
  elem_id="page_title",
29
  )
30
- with gr.Column(elem_id="global_base_model_select_group"):
 
 
 
31
  global_base_model_select = gr.Dropdown(
32
  label="Base Model",
33
  elem_id="global_base_model_select",
@@ -99,6 +104,19 @@ def main_page():
99
  ]
100
  )
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  main_page_blocks.load(_js=f"""
103
  function () {{
104
  {popperjs_core_code()}
@@ -239,6 +257,12 @@ def main_page_custom_css():
239
  }
240
  */
241
 
 
 
 
 
 
 
242
  .error-message, .error-message p {
243
  color: var(--error-text-color) !important;
244
  }
@@ -261,6 +285,36 @@ def main_page_custom_css():
261
  max-height: unset;
262
  }
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  #page_title {
265
  flex-grow: 3;
266
  }
 
18
  title=title,
19
  css=get_css_styles(),
20
  ) as main_page_blocks:
21
+ training_indicator = gr.HTML(
22
+ "", visible=False, elem_id="training_indicator")
23
  with gr.Column(elem_id="main_page_content"):
24
  with gr.Row():
25
  gr.Markdown(
 
29
  """,
30
  elem_id="page_title",
31
  )
32
+ with gr.Column(
33
+ elem_id="global_base_model_select_group",
34
+ elem_classes="disable_while_training without_message"
35
+ ):
36
  global_base_model_select = gr.Dropdown(
37
  label="Base Model",
38
  elem_id="global_base_model_select",
 
104
  ]
105
  )
106
 
107
+ main_page_blocks.load(
108
+ fn=lambda: gr.HTML.update(
109
+ visible=Global.is_training or Global.is_train_starting,
110
+ value=Global.is_training and "training"
111
+ or (
112
+ Global.is_train_starting and "train_starting" or ""
113
+ )
114
+ ),
115
+ inputs=None,
116
+ outputs=[training_indicator],
117
+ every=2
118
+ )
119
+
120
  main_page_blocks.load(_js=f"""
121
  function () {{
122
  {popperjs_core_code()}
 
257
  }
258
  */
259
 
260
+ .hide_wrap > .wrap {
261
+ border: 0;
262
+ background: transparent;
263
+ pointer-events: none;
264
+ }
265
+
266
  .error-message, .error-message p {
267
  color: var(--error-text-color) !important;
268
  }
 
285
  max-height: unset;
286
  }
287
 
288
+ #training_indicator { display: none; }
289
+ #training_indicator:not(.hidden) ~ * .disable_while_training {
290
+ position: relative !important;
291
+ pointer-events: none !important;
292
+ }
293
+ #training_indicator:not(.hidden) ~ * .disable_while_training * {
294
+ pointer-events: none !important;
295
+ }
296
+ #training_indicator:not(.hidden) ~ * .disable_while_training::after {
297
+ content: "Disabled while training is in progress";
298
+ display: flex;
299
+ position: absolute !important;
300
+ z-index: 70;
301
+ top: 0;
302
+ left: 0;
303
+ right: 0;
304
+ bottom: 0;
305
+ background: var(--block-background-fill);
306
+ opacity: 0.7;
307
+ justify-content: center;
308
+ align-items: center;
309
+ color: var(--body-text-color);
310
+ font-size: var(--text-lg);
311
+ font-weight: var(--weight-bold);
312
+ text-transform: uppercase;
313
+ }
314
+ #training_indicator:not(.hidden) ~ * .disable_while_training.without_message::after {
315
+ content: "";
316
+ }
317
+
318
  #page_title {
319
  flex-grow: 3;
320
  }
llama_lora/ui/tokenizer_ui.py CHANGED
@@ -41,7 +41,7 @@ def tokenizer_ui():
41
  things_that_might_timeout = []
42
 
43
  with gr.Blocks() as tokenizer_ui_blocks:
44
- with gr.Row():
45
  with gr.Column():
46
  encoded_tokens = gr.Code(
47
  label="Encoded Tokens (JSON)",
 
41
  things_that_might_timeout = []
42
 
43
  with gr.Blocks() as tokenizer_ui_blocks:
44
+ with gr.Row(elem_classes="disable_while_training"):
45
  with gr.Column():
46
  encoded_tokens = gr.Code(
47
  label="Encoded Tokens (JSON)",
llama_lora/ui/trainer_callback.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import traceback
3
+ from transformers import TrainerCallback
4
+
5
+ from ..globals import Global
6
+ from ..utils.eta_predictor import ETAPredictor
7
+
8
+
9
+ def reset_training_status():
10
+ Global.is_train_starting = False
11
+ Global.is_training = False
12
+ Global.should_stop_training = False
13
+ Global.train_started_at = time.time()
14
+ Global.training_error_message = None
15
+ Global.training_error_detail = None
16
+ Global.training_total_epochs = 1
17
+ Global.training_current_epoch = 0.0
18
+ Global.training_total_steps = 1
19
+ Global.training_current_step = 0
20
+ Global.training_progress = 0.0
21
+ Global.training_log_history = []
22
+ Global.training_status_text = ""
23
+ Global.training_eta_predictor = ETAPredictor()
24
+ Global.training_eta = None
25
+ Global.train_output = None
26
+ Global.train_output_str = None
27
+
28
+
29
+ def get_progress_text(current_epoch, total_epochs, last_loss):
30
+ progress_detail = f"Epoch {current_epoch:.2f}/{total_epochs}"
31
+ if last_loss is not None:
32
+ progress_detail += f", Loss: {last_loss:.4f}"
33
+ return f"Training... ({progress_detail})"
34
+
35
+
36
+ def set_train_output(output):
37
+ end_by = 'aborted' if Global.should_stop_training else 'completed'
38
+ result_message = f"Training {end_by}"
39
+ Global.training_status_text = result_message
40
+
41
+ Global.train_output = output
42
+ Global.train_output_str = str(output)
43
+
44
+ return result_message
45
+
46
+
47
+ def update_training_states(
48
+ current_step, total_steps,
49
+ current_epoch, total_epochs,
50
+ log_history):
51
+
52
+ Global.training_total_steps = total_steps
53
+ Global.training_current_step = current_step
54
+ Global.training_total_epochs = total_epochs
55
+ Global.training_current_epoch = current_epoch
56
+ Global.training_progress = current_step / total_steps
57
+ Global.training_log_history = log_history
58
+ Global.training_eta = Global.training_eta_predictor.predict_eta(current_step, total_steps)
59
+
60
+ last_history = None
61
+ last_loss = None
62
+ if len(Global.training_log_history) > 0:
63
+ last_history = log_history[-1]
64
+ last_loss = last_history.get('loss', None)
65
+
66
+ Global.training_status_text = get_progress_text(
67
+ total_epochs=total_epochs,
68
+ current_epoch=current_epoch,
69
+ last_loss=last_loss,
70
+ )
71
+
72
+
73
+ class UiTrainerCallback(TrainerCallback):
74
+ def _on_progress(self, args, state, control):
75
+ if Global.should_stop_training:
76
+ control.should_training_stop = True
77
+
78
+ try:
79
+ total_steps = (
80
+ state.max_steps if state.max_steps is not None
81
+ else state.num_train_epochs * state.steps_per_epoch)
82
+ current_step = state.global_step
83
+
84
+ total_epochs = args.num_train_epochs
85
+ current_epoch = state.epoch
86
+
87
+ log_history = state.log_history
88
+
89
+ update_training_states(
90
+ total_steps=total_steps,
91
+ current_step=current_step,
92
+ total_epochs=total_epochs,
93
+ current_epoch=current_epoch,
94
+ log_history=log_history
95
+ )
96
+ except Exception as e:
97
+ print("Error occurred while updating UI status:", e)
98
+ traceback.print_exc()
99
+
100
+ def on_epoch_begin(self, args, state, control, **kwargs):
101
+ self._on_progress(args, state, control)
102
+
103
+ def on_step_end(self, args, state, control, **kwargs):
104
+ self._on_progress(args, state, control)
llama_lora/utils/eta_predictor.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import traceback
3
+ from collections import deque
4
+ from typing import Optional
5
+
6
+
7
+ class ETAPredictor:
8
+ def __init__(self, lookback_minutes: int = 180):
9
+ self.lookback_seconds = lookback_minutes * 60 # convert minutes to seconds
10
+ self.data = deque()
11
+
12
+ def _cleanup_old_data(self):
13
+ current_time = time.time()
14
+ while self.data and current_time - self.data[0][1] > self.lookback_seconds:
15
+ self.data.popleft()
16
+
17
+ def predict_eta(
18
+ self, current_step: int, total_steps: int
19
+ ) -> Optional[int]:
20
+ try:
21
+ current_time = time.time()
22
+
23
+ # Calculate dynamic log interval based on current logged data
24
+ log_interval = 1
25
+ if len(self.data) > 100:
26
+ log_interval = 10
27
+
28
+ # Only log data if last log is at least log_interval seconds ago
29
+ if len(self.data) < 1 or current_time - self.data[-1][1] >= log_interval:
30
+ self.data.append((current_step, current_time))
31
+ self._cleanup_old_data()
32
+
33
+ # Only predict if we have enough data
34
+ if len(self.data) < 2 or self.data[-1][1] - self.data[0][1] < 5:
35
+ return None
36
+
37
+ first_step, first_time = self.data[0]
38
+ steps_completed = current_step - first_step
39
+ time_elapsed = current_time - first_time
40
+
41
+ if steps_completed == 0:
42
+ return None
43
+
44
+ time_per_step = time_elapsed / steps_completed
45
+ steps_remaining = total_steps - current_step
46
+
47
+ remaining_seconds = steps_remaining * time_per_step
48
+ eta_unix_timestamp = current_time + remaining_seconds
49
+
50
+ return int(eta_unix_timestamp)
51
+ except Exception as e:
52
+ print("Error predicting ETA:", e)
53
+ traceback.print_exc()
54
+ return None