zetavg committed on
Commit
9ee06c7
•
1 Parent(s): 49ce4b9

let finetune ui support showing training progress

Browse files
llama_lora/globals.py CHANGED
@@ -17,6 +17,9 @@ class Global:
17
  # Functions
18
  train_fn: Any = None
19
 
 
 
 
20
  # UI related
21
  ui_title: str = "LLaMA-LoRA"
22
  ui_emoji: str = "🦙🎛️"
 
17
  # Functions
18
  train_fn: Any = None
19
 
20
+ # Training Control
21
+ should_stop_training = False
22
+
23
  # UI related
24
  ui_title: str = "LLaMA-LoRA"
25
  ui_emoji: str = "🦙🎛️"
llama_lora/ui/finetune_ui.py CHANGED
@@ -5,6 +5,8 @@ from datetime import datetime
5
  import gradio as gr
6
  from random_word import RandomWords
7
 
 
 
8
  from ..globals import Global
9
  from ..models import get_base_model, get_tokenizer
10
  from ..utils.data import (
@@ -331,6 +333,31 @@ Train data (first 10):
331
  time.sleep(2)
332
  return message
333
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  return Global.train_fn(
335
  get_base_model(), # base_model
336
  get_tokenizer(), # tokenizer
@@ -351,11 +378,16 @@ Train data (first 10):
351
  True, # train_on_inputs
352
  False, # group_by_length
353
  None, # resume_from_checkpoint
 
354
  )
355
  except Exception as e:
356
  raise gr.Error(e)
357
 
358
 
 
 
 
 
359
  def finetune_ui():
360
  with gr.Blocks() as finetune_ui_blocks:
361
  with gr.Column(elem_id="finetune_ui_content"):
@@ -580,7 +612,10 @@ def finetune_ui():
580
 
581
  # controlled by JS, shows the confirm_abort_button
582
  abort_button.click(None, None, None, None)
583
- confirm_abort_button.click(None, None, None, cancels=[train_progress])
 
 
 
584
 
585
  finetune_ui_blocks.load(_js="""
586
  function finetune_ui_blocks_js() {
 
5
  import gradio as gr
6
  from random_word import RandomWords
7
 
8
+ from transformers import TrainerCallback
9
+
10
  from ..globals import Global
11
  from ..models import get_base_model, get_tokenizer
12
  from ..utils.data import (
 
333
  time.sleep(2)
334
  return message
335
 
336
class UiTrainerCallback(TrainerCallback):
    """Trainer callback that relays abort requests and progress to the UI.

    At the start of every epoch and at the end of every step it:

    1. checks ``Global.should_stop_training`` and, when it is set, sets
       ``control.should_training_stop`` on the trainer's control object;
    2. calls ``progress`` with ``(global_step, total_steps)`` and a
       human-readable description.

    ``progress`` and ``epochs`` are not defined in this class; they are
    read from the enclosing scope.
    """

    def _sync_with_ui(self, state, control):
        """Apply a pending stop request, then report the current step."""
        if Global.should_stop_training:
            control.should_training_stop = True

        # NOTE(review): the previous code fell back to
        # ``state.num_train_epochs * state.steps_per_epoch`` when
        # ``state.max_steps`` was None. transformers' TrainerState documents
        # ``max_steps`` as an int and has no ``steps_per_epoch`` field, so
        # that branch was unreachable and would have raised AttributeError
        # if taken -- confirm against the installed transformers version.
        total_steps = state.max_steps
        progress(
            (state.global_step, total_steps),
            desc=f"Training... (Epoch {state.epoch}/{epochs}, Step {state.global_step}/{total_steps})"
        )

    def on_epoch_begin(self, args, state, control, **kwargs):
        """Hook called when an epoch begins; see ``_sync_with_ui``."""
        self._sync_with_ui(state, control)

    def on_step_end(self, args, state, control, **kwargs):
        """Hook called when a training step ends; see ``_sync_with_ui``."""
        self._sync_with_ui(state, control)
356
+
357
+ training_callbacks = [UiTrainerCallback]
358
+
359
+ Global.should_stop_training = False
360
+
361
  return Global.train_fn(
362
  get_base_model(), # base_model
363
  get_tokenizer(), # tokenizer
 
378
  True, # train_on_inputs
379
  False, # group_by_length
380
  None, # resume_from_checkpoint
381
+ training_callbacks # callbacks
382
  )
383
  except Exception as e:
384
  raise gr.Error(e)
385
 
386
 
387
def do_abort_training():
    """Flag the running training job to stop.

    Only sets ``Global.should_stop_training`` to True; nothing is
    interrupted here.
    """
    Global.should_stop_training = True
389
+
390
+
391
  def finetune_ui():
392
  with gr.Blocks() as finetune_ui_blocks:
393
  with gr.Column(elem_id="finetune_ui_content"):
 
612
 
613
  # controlled by JS, shows the confirm_abort_button
614
  abort_button.click(None, None, None, None)
615
+ confirm_abort_button.click(
616
+ fn=do_abort_training,
617
+ inputs=None, outputs=None,
618
+ cancels=[train_progress])
619
 
620
  finetune_ui_blocks.load(_js="""
621
  function finetune_ui_blocks_js() {