Spaces:
Runtime error
Runtime error
zetavg
commited on
Commit
β’
9ee06c7
1
Parent(s):
49ce4b9
let finetune ui support showing training progress
Browse files- llama_lora/globals.py +3 -0
- llama_lora/ui/finetune_ui.py +36 -1
llama_lora/globals.py
CHANGED
@@ -17,6 +17,9 @@ class Global:
|
|
17 |
# Functions
|
18 |
train_fn: Any = None
|
19 |
|
|
|
|
|
|
|
20 |
# UI related
|
21 |
ui_title: str = "LLaMA-LoRA"
|
22 |
ui_emoji: str = "π¦ποΈ"
|
|
|
17 |
# Functions
|
18 |
train_fn: Any = None
|
19 |
|
20 |
+
# Training Control
|
21 |
+
should_stop_training = False
|
22 |
+
|
23 |
# UI related
|
24 |
ui_title: str = "LLaMA-LoRA"
|
25 |
ui_emoji: str = "π¦ποΈ"
|
llama_lora/ui/finetune_ui.py
CHANGED
@@ -5,6 +5,8 @@ from datetime import datetime
|
|
5 |
import gradio as gr
|
6 |
from random_word import RandomWords
|
7 |
|
|
|
|
|
8 |
from ..globals import Global
|
9 |
from ..models import get_base_model, get_tokenizer
|
10 |
from ..utils.data import (
|
@@ -331,6 +333,31 @@ Train data (first 10):
|
|
331 |
time.sleep(2)
|
332 |
return message
|
333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
return Global.train_fn(
|
335 |
get_base_model(), # base_model
|
336 |
get_tokenizer(), # tokenizer
|
@@ -351,11 +378,16 @@ Train data (first 10):
|
|
351 |
True, # train_on_inputs
|
352 |
False, # group_by_length
|
353 |
None, # resume_from_checkpoint
|
|
|
354 |
)
|
355 |
except Exception as e:
|
356 |
raise gr.Error(e)
|
357 |
|
358 |
|
|
|
|
|
|
|
|
|
359 |
def finetune_ui():
|
360 |
with gr.Blocks() as finetune_ui_blocks:
|
361 |
with gr.Column(elem_id="finetune_ui_content"):
|
@@ -580,7 +612,10 @@ def finetune_ui():
|
|
580 |
|
581 |
# controlled by JS, shows the confirm_abort_button
|
582 |
abort_button.click(None, None, None, None)
|
583 |
-
confirm_abort_button.click(
|
|
|
|
|
|
|
584 |
|
585 |
finetune_ui_blocks.load(_js="""
|
586 |
function finetune_ui_blocks_js() {
|
|
|
5 |
import gradio as gr
|
6 |
from random_word import RandomWords
|
7 |
|
8 |
+
from transformers import TrainerCallback
|
9 |
+
|
10 |
from ..globals import Global
|
11 |
from ..models import get_base_model, get_tokenizer
|
12 |
from ..utils.data import (
|
|
|
333 |
time.sleep(2)
|
334 |
return message
|
335 |
|
336 |
+
class UiTrainerCallback(TrainerCallback):
|
337 |
+
def on_epoch_begin(self, args, state, control, **kwargs):
|
338 |
+
if Global.should_stop_training:
|
339 |
+
control.should_training_stop = True
|
340 |
+
total_steps = (
|
341 |
+
state.max_steps if state.max_steps is not None else state.num_train_epochs * state.steps_per_epoch)
|
342 |
+
progress(
|
343 |
+
(state.global_step, total_steps),
|
344 |
+
desc=f"Training... (Epoch {state.epoch}/{epochs}, Step {state.global_step}/{total_steps})"
|
345 |
+
)
|
346 |
+
|
347 |
+
def on_step_end(self, args, state, control, **kwargs):
|
348 |
+
if Global.should_stop_training:
|
349 |
+
control.should_training_stop = True
|
350 |
+
total_steps = (
|
351 |
+
state.max_steps if state.max_steps is not None else state.num_train_epochs * state.steps_per_epoch)
|
352 |
+
progress(
|
353 |
+
(state.global_step, total_steps),
|
354 |
+
desc=f"Training... (Epoch {state.epoch}/{epochs}, Step {state.global_step}/{total_steps})"
|
355 |
+
)
|
356 |
+
|
357 |
+
training_callbacks = [UiTrainerCallback]
|
358 |
+
|
359 |
+
Global.should_stop_training = False
|
360 |
+
|
361 |
return Global.train_fn(
|
362 |
get_base_model(), # base_model
|
363 |
get_tokenizer(), # tokenizer
|
|
|
378 |
True, # train_on_inputs
|
379 |
False, # group_by_length
|
380 |
None, # resume_from_checkpoint
|
381 |
+
training_callbacks # callbacks
|
382 |
)
|
383 |
except Exception as e:
|
384 |
raise gr.Error(e)
|
385 |
|
386 |
|
387 |
+
def do_abort_training():
|
388 |
+
Global.should_stop_training = True
|
389 |
+
|
390 |
+
|
391 |
def finetune_ui():
|
392 |
with gr.Blocks() as finetune_ui_blocks:
|
393 |
with gr.Column(elem_id="finetune_ui_content"):
|
|
|
612 |
|
613 |
# controlled by JS, shows the confirm_abort_button
|
614 |
abort_button.click(None, None, None, None)
|
615 |
+
confirm_abort_button.click(
|
616 |
+
fn=do_abort_training,
|
617 |
+
inputs=None, outputs=None,
|
618 |
+
cancels=[train_progress])
|
619 |
|
620 |
finetune_ui_blocks.load(_js="""
|
621 |
function finetune_ui_blocks_js() {
|