zetavg committed
Commit 6947876
1 Parent(s): b606ad0

finetune loss chart: use steps as the x axis if possible

llama_lora/globals.py CHANGED
@@ -5,7 +5,7 @@ import psutil
 import math
 
 from typing import Any, Dict, List, Optional, Tuple, Union
-
+from transformers import TrainingArguments
 from numba import cuda
 import nvidia_smi
 
@@ -47,6 +47,7 @@ class Global:
     training_status_text: str = ""
     training_eta_predictor = ETAPredictor()
     training_eta: Union[int, None] = None
+    training_args: Union[TrainingArguments, None] = None
     train_output: Union[None, Any] = None
     train_output_str: Union[None, str] = None
     training_params_info_text: str = ""
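
A minimal sketch (not code from this repo) of how this new optional global is meant to be consumed elsewhere in the UI: because `None and ...` short-circuits, readers get a falsy value until a real TrainingArguments has been stored.

from typing import Union
from transformers import TrainingArguments

# Illustrative stand-in for Global.training_args; None until training starts.
training_args: Union[TrainingArguments, None] = None

def current_logging_steps():
    # `None and ...` short-circuits, so this stays falsy before training begins.
    return training_args and training_args.logging_steps

print(current_logging_steps())  # None

training_args = TrainingArguments(output_dir="", logging_steps=10)
print(current_logging_steps())  # 10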
llama_lora/ui/finetune/training.py CHANGED
@@ -12,11 +12,13 @@ import pandas as pd
 import gradio as gr
 
 from huggingface_hub import try_to_load_from_cache, snapshot_download
+from transformers import TrainingArguments
 
 from ...config import Config
 from ...globals import Global
 from ...models import clear_cache, unload_models
 from ...utils.prompter import Prompter
+from ...utils.sample_evenly import sample_evenly
 from ..trainer_callback import (
     UiTrainerCallback, reset_training_status,
     update_training_states, set_train_output
@@ -202,26 +204,31 @@ def do_train(
     train_data = prompter.get_train_data_from_dataset(data)
 
     if Config.ui_dev_mode:
+        Global.training_args = TrainingArguments(
+            logging_steps=logging_steps, output_dir=""
+        )
+
         message = "Currently in UI dev mode, not doing the actual training."
         message += f"\n\nArgs: {json.dumps(finetune_args, indent=2)}"
         message += f"\n\nTrain data (first 5):\n{json.dumps(train_data[:5], indent=2)}"
 
         print(message)
 
-        total_steps = 300
+        total_epochs = epochs
+        total_steps = len(train_data) * epochs
         log_history = []
         initial_loss = 2
        loss_decay_rate = 0.8
-        for i in range(300):
+        for i in range(total_steps):
             if (Global.should_stop_training):
                 break
 
             current_step = i + 1
-            total_epochs = 3
-            current_epoch = i / 100
+            current_epoch = i / (total_steps / total_epochs)
 
-            if (i > 20):
-                loss = initial_loss * math.exp(-loss_decay_rate * current_epoch)
+            if (current_step % logging_steps == 0):
+                loss = initial_loss * \
+                    math.exp(-loss_decay_rate * current_epoch)
             log_history.append({
                 'loss': loss,
                 'learning_rate': 0.0001,
@@ -424,7 +431,10 @@ def render_loss_plot():
     if len(Global.training_log_history) <= 2:
         return (gr.Column.update(visible=False), gr.Plot.update(visible=False))
 
-    training_log_history = Global.training_log_history
+    max_elements = 5000
+    training_log_history = sample_evenly(
+        Global.training_log_history, max_elements=max_elements)
+    logging_steps = Global.training_args and Global.training_args.logging_steps
 
     loss_data = [
         {
@@ -436,6 +446,12 @@
         and 'epoch' in item
     ]
 
+    use_steps = False
+    if len(Global.training_log_history) <= max_elements and logging_steps:
+        for index, item in enumerate(loss_data):
+            item["step"] = index * logging_steps
+        use_steps = True
+
     source = pd.DataFrame(loss_data)
 
     highlight = alt.selection(
@@ -443,12 +459,20 @@
         on='mouseover', fields=['type'], nearest=True
     )
 
-    base = alt.Chart(source).encode(  # type: ignore
-        x='epoch:Q',
-        y='loss:Q',
-        color='type:N',
-        tooltip=['type:N', 'loss:Q', 'epoch:Q']
-    )
+    if use_steps:
+        base = alt.Chart(source).encode(  # type: ignore
+            x='step:Q',
+            y='loss:Q',
+            color='type:N',
+            tooltip=['type:N', 'loss:Q', 'step:Q', 'epoch:Q']
+        )
+    else:
+        base = alt.Chart(source).encode(  # type: ignore
+            x='epoch:Q',
+            y='loss:Q',
+            color='type:N',
+            tooltip=['type:N', 'loss:Q', 'epoch:Q']
+        )
 
     points = base.mark_circle().encode(
         opacity=alt.value(0)
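
The simulated dev-mode loss above is plain exponential decay in the epoch; a standalone sketch with made-up numbers (100 examples, 3 epochs, logging_steps=10, none of which come from this commit) shows the shape of the curve the chart renders.

import math

# Hypothetical inputs for illustration only; the real values come from the finetune form.
train_examples = 100
epochs = 3
logging_steps = 10

initial_loss = 2
loss_decay_rate = 0.8
total_steps = train_examples * epochs  # 300 simulated steps

for i in range(total_steps):
    current_step = i + 1
    current_epoch = i / (total_steps / epochs)
    if current_step % logging_steps == 0:
        loss = initial_loss * math.exp(-loss_decay_rate * current_epoch)
        print(f"step {current_step:3d}  epoch {current_epoch:.2f}  loss {loss:.4f}")
# loss starts near 2 and decays to roughly 0.18 by the end of epoch 3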
llama_lora/ui/trainer_callback.py CHANGED
@@ -22,6 +22,7 @@ def reset_training_status():
     Global.training_status_text = ""
     Global.training_eta_predictor = ETAPredictor()
     Global.training_eta = None
+    Global.training_args = None
     Global.train_output = None
     Global.train_output_str = None
     Global.training_params_info_text = ""
@@ -102,6 +103,7 @@ class UiTrainerCallback(TrainerCallback):
             traceback.print_exc()
 
     def on_epoch_begin(self, args, state, control, **kwargs):
+        Global.training_args = args
         self._on_progress(args, state, control)
 
     def on_step_end(self, args, state, control, **kwargs):
llama_lora/utils/sample_evenly.py ADDED
@@ -0,0 +1,15 @@
+import numpy as np
+from typing import List, Any, Iterator
+
+
+def sample_evenly_it(input_list: List[Any], max_elements: int = 1000) -> Iterator[Any]:
+    if len(input_list) <= max_elements:
+        yield from input_list
+    else:
+        step = len(input_list) / max_elements
+        indices = np.arange(0, len(input_list), step).astype(int)
+        yield from (input_list[i] for i in indices)
+
+
+def sample_evenly(input_list: List[Any], max_elements: int = 1000) -> List[Any]:
+    return list(sample_evenly_it(input_list, max_elements))
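
A quick usage sketch for the new helper (the numbers are illustrative): it keeps at most max_elements items, picked at evenly spaced indices, so very long log histories stay cheap to plot.

from llama_lora.utils.sample_evenly import sample_evenly

history = list(range(10_000))              # stand-in for Global.training_log_history
sampled = sample_evenly(history, max_elements=1000)

print(len(sampled))   # 1000
print(sampled[:3])    # [0, 10, 20] -- indices spaced len(history) / max_elements apart
print(sampled[-1])    # 9990
print(sample_evenly([1, 2, 3], max_elements=1000))  # [1, 2, 3] -- short lists pass through unchanged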