Spaces:
Runtime error
Runtime error
zetavg
commited on
show loss/epoch chart on finetune ui
Browse files- llama_lora/ui/finetune/finetune_ui.py +16 -3
- llama_lora/ui/finetune/style.css +27 -1
- llama_lora/ui/finetune/training.py +59 -5
- requirements.txt +5 -3
llama_lora/ui/finetune/finetune_ui.py
CHANGED
@@ -28,7 +28,8 @@ from .previewing import (
|
|
28 |
)
|
29 |
from .training import (
|
30 |
do_train,
|
31 |
-
render_training_status
|
|
|
32 |
)
|
33 |
|
34 |
register_css_style('finetune', relative_read_file(__file__, "style.css"))
|
@@ -773,10 +774,15 @@ def finetune_ui():
|
|
773 |
)
|
774 |
|
775 |
train_status = gr.HTML(
|
776 |
-
"
|
777 |
label="Train Output",
|
778 |
elem_id="finetune_training_status")
|
779 |
|
|
|
|
|
|
|
|
|
|
|
780 |
training_indicator = gr.HTML(
|
781 |
"training_indicator", visible=False, elem_id="finetune_training_indicator")
|
782 |
|
@@ -787,7 +793,8 @@ def finetune_ui():
|
|
787 |
continue_from_model,
|
788 |
continue_from_checkpoint,
|
789 |
]),
|
790 |
-
outputs=[train_status, training_indicator
|
|
|
791 |
)
|
792 |
|
793 |
# controlled by JS, shows the confirm_abort_button
|
@@ -803,6 +810,12 @@ def finetune_ui():
|
|
803 |
outputs=[train_status, training_indicator],
|
804 |
every=0.2
|
805 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
806 |
finetune_ui_blocks.load(_js=relative_read_file(__file__, "script.js"))
|
807 |
|
808 |
# things_that_might_timeout.append(training_status_updates)
|
|
|
28 |
)
|
29 |
from .training import (
|
30 |
do_train,
|
31 |
+
render_training_status,
|
32 |
+
render_loss_plot
|
33 |
)
|
34 |
|
35 |
register_css_style('finetune', relative_read_file(__file__, "style.css"))
|
|
|
774 |
)
|
775 |
|
776 |
train_status = gr.HTML(
|
777 |
+
"",
|
778 |
label="Train Output",
|
779 |
elem_id="finetune_training_status")
|
780 |
|
781 |
+
with gr.Column(visible=False, elem_id="finetune_loss_plot_container") as loss_plot_container:
|
782 |
+
loss_plot = gr.Plot(
|
783 |
+
visible=False, show_label=False,
|
784 |
+
elem_id="finetune_loss_plot")
|
785 |
+
|
786 |
training_indicator = gr.HTML(
|
787 |
"training_indicator", visible=False, elem_id="finetune_training_indicator")
|
788 |
|
|
|
793 |
continue_from_model,
|
794 |
continue_from_checkpoint,
|
795 |
]),
|
796 |
+
outputs=[train_status, training_indicator,
|
797 |
+
loss_plot_container, loss_plot]
|
798 |
)
|
799 |
|
800 |
# controlled by JS, shows the confirm_abort_button
|
|
|
810 |
outputs=[train_status, training_indicator],
|
811 |
every=0.2
|
812 |
)
|
813 |
+
loss_plot_updates = finetune_ui_blocks.load(
|
814 |
+
fn=render_loss_plot,
|
815 |
+
inputs=None,
|
816 |
+
outputs=[loss_plot_container, loss_plot],
|
817 |
+
every=10
|
818 |
+
)
|
819 |
finetune_ui_blocks.load(_js=relative_read_file(__file__, "script.js"))
|
820 |
|
821 |
# things_that_might_timeout.append(training_status_updates)
|
llama_lora/ui/finetune/style.css
CHANGED
@@ -255,7 +255,9 @@
|
|
255 |
display: none;
|
256 |
}
|
257 |
|
258 |
-
#finetune_training_status > .wrap
|
|
|
|
|
259 |
border: 0;
|
260 |
background: transparent;
|
261 |
pointer-events: none;
|
@@ -264,6 +266,17 @@
|
|
264 |
left: 0;
|
265 |
right: 0;
|
266 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
#finetune_training_status > .wrap .meta-text-center {
|
268 |
transform: none !important;
|
269 |
}
|
@@ -383,5 +396,18 @@
|
|
383 |
/* background: var(--error-background-fill) !important; */
|
384 |
border: 1px solid var(--error-border-color) !important;
|
385 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
386 |
|
387 |
#finetune_training_indicator { display: none; }
|
|
|
255 |
display: none;
|
256 |
}
|
257 |
|
258 |
+
#finetune_training_status > .wrap,
|
259 |
+
#finetune_loss_plot_container > .wrap,
|
260 |
+
#finetune_loss_plot > .wrap {
|
261 |
border: 0;
|
262 |
background: transparent;
|
263 |
pointer-events: none;
|
|
|
266 |
left: 0;
|
267 |
right: 0;
|
268 |
}
|
269 |
+
#finetune_training_status > .wrap:not(.generating)::after {
|
270 |
+
content: "Refresh the page if this takes too long.";
|
271 |
+
position: absolute;
|
272 |
+
top: 0;
|
273 |
+
left: 0;
|
274 |
+
right: 0;
|
275 |
+
bottom: 0;
|
276 |
+
padding-top: 64px;
|
277 |
+
opacity: 0.5;
|
278 |
+
text-align: center;
|
279 |
+
}
|
280 |
#finetune_training_status > .wrap .meta-text-center {
|
281 |
transform: none !important;
|
282 |
}
|
|
|
396 |
/* background: var(--error-background-fill) !important; */
|
397 |
border: 1px solid var(--error-border-color) !important;
|
398 |
}
|
399 |
+
#finetune_loss_plot {
|
400 |
+
padding: var(--block-padding);
|
401 |
+
}
|
402 |
+
#finetune_loss_plot .altair {
|
403 |
+
overflow: auto !important;
|
404 |
+
}
|
405 |
+
#finetune_loss_plot .altair > * {
|
406 |
+
margin: auto !important;
|
407 |
+
}
|
408 |
+
#finetune_loss_plot .vega-embed summary {
|
409 |
+
border: 0;
|
410 |
+
box-shadow: none;
|
411 |
+
}
|
412 |
|
413 |
#finetune_training_indicator { display: none; }
|
llama_lora/ui/finetune/training.py
CHANGED
@@ -1,11 +1,14 @@
|
|
1 |
import os
|
2 |
import json
|
3 |
import time
|
|
|
4 |
import datetime
|
5 |
import pytz
|
6 |
import socket
|
7 |
import threading
|
8 |
import traceback
|
|
|
|
|
9 |
import gradio as gr
|
10 |
|
11 |
from huggingface_hub import try_to_load_from_cache, snapshot_download
|
@@ -71,7 +74,7 @@ def do_train(
|
|
71 |
progress=gr.Progress(track_tqdm=False),
|
72 |
):
|
73 |
if Global.is_training or Global.is_train_starting:
|
74 |
-
return render_training_status()
|
75 |
|
76 |
reset_training_status()
|
77 |
Global.is_train_starting = True
|
@@ -206,6 +209,9 @@ def do_train(
|
|
206 |
print(message)
|
207 |
|
208 |
total_steps = 300
|
|
|
|
|
|
|
209 |
for i in range(300):
|
210 |
if (Global.should_stop_training):
|
211 |
break
|
@@ -213,11 +219,14 @@ def do_train(
|
|
213 |
current_step = i + 1
|
214 |
total_epochs = 3
|
215 |
current_epoch = i / 100
|
216 |
-
log_history = []
|
217 |
|
218 |
if (i > 20):
|
219 |
-
loss =
|
220 |
-
log_history
|
|
|
|
|
|
|
|
|
221 |
|
222 |
update_training_states(
|
223 |
total_steps=total_steps,
|
@@ -295,7 +304,7 @@ def do_train(
|
|
295 |
finally:
|
296 |
Global.is_train_starting = False
|
297 |
|
298 |
-
return render_training_status()
|
299 |
|
300 |
|
301 |
def render_training_status():
|
@@ -411,6 +420,51 @@ def render_training_status():
|
|
411 |
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
|
412 |
|
413 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
414 |
def format_time(seconds):
|
415 |
hours, remainder = divmod(seconds, 3600)
|
416 |
minutes, seconds = divmod(remainder, 60)
|
|
|
1 |
import os
|
2 |
import json
|
3 |
import time
|
4 |
+
import math
|
5 |
import datetime
|
6 |
import pytz
|
7 |
import socket
|
8 |
import threading
|
9 |
import traceback
|
10 |
+
import altair as alt
|
11 |
+
import pandas as pd
|
12 |
import gradio as gr
|
13 |
|
14 |
from huggingface_hub import try_to_load_from_cache, snapshot_download
|
|
|
74 |
progress=gr.Progress(track_tqdm=False),
|
75 |
):
|
76 |
if Global.is_training or Global.is_train_starting:
|
77 |
+
return render_training_status() + render_loss_plot()
|
78 |
|
79 |
reset_training_status()
|
80 |
Global.is_train_starting = True
|
|
|
209 |
print(message)
|
210 |
|
211 |
total_steps = 300
|
212 |
+
log_history = []
|
213 |
+
initial_loss = 2
|
214 |
+
loss_decay_rate = 0.8
|
215 |
for i in range(300):
|
216 |
if (Global.should_stop_training):
|
217 |
break
|
|
|
219 |
current_step = i + 1
|
220 |
total_epochs = 3
|
221 |
current_epoch = i / 100
|
|
|
222 |
|
223 |
if (i > 20):
|
224 |
+
loss = initial_loss * math.exp(-loss_decay_rate * current_epoch)
|
225 |
+
log_history.append({
|
226 |
+
'loss': loss,
|
227 |
+
'learning_rate': 0.0001,
|
228 |
+
'epoch': current_epoch
|
229 |
+
})
|
230 |
|
231 |
update_training_states(
|
232 |
total_steps=total_steps,
|
|
|
304 |
finally:
|
305 |
Global.is_train_starting = False
|
306 |
|
307 |
+
return render_training_status() + render_loss_plot()
|
308 |
|
309 |
|
310 |
def render_training_status():
|
|
|
420 |
return (gr.HTML.update(value=html_content), gr.HTML.update(visible=True))
|
421 |
|
422 |
|
423 |
+
def render_loss_plot():
|
424 |
+
if len(Global.training_log_history) <= 2:
|
425 |
+
return (gr.Column.update(visible=False), gr.Plot.update(visible=False))
|
426 |
+
|
427 |
+
training_log_history = Global.training_log_history
|
428 |
+
|
429 |
+
loss_data = [
|
430 |
+
{
|
431 |
+
'type': 'train_loss' if 'loss' in item else 'eval_loss',
|
432 |
+
'loss': item.get('loss') or item.get('eval_loss'),
|
433 |
+
'epoch': item.get('epoch')
|
434 |
+
} for item in training_log_history
|
435 |
+
if ('loss' in item or 'eval_loss' in item)
|
436 |
+
and 'epoch' in item
|
437 |
+
]
|
438 |
+
|
439 |
+
source = pd.DataFrame(loss_data)
|
440 |
+
|
441 |
+
highlight = alt.selection(
|
442 |
+
type='single', # type: ignore
|
443 |
+
on='mouseover', fields=['type'], nearest=True
|
444 |
+
)
|
445 |
+
|
446 |
+
base = alt.Chart(source).encode( # type: ignore
|
447 |
+
x='epoch:Q',
|
448 |
+
y='loss:Q',
|
449 |
+
color='type:N',
|
450 |
+
tooltip=['type:N', 'loss:Q', 'epoch:Q']
|
451 |
+
)
|
452 |
+
|
453 |
+
points = base.mark_circle().encode(
|
454 |
+
opacity=alt.value(0)
|
455 |
+
).add_selection(
|
456 |
+
highlight
|
457 |
+
).properties(
|
458 |
+
width=640
|
459 |
+
)
|
460 |
+
|
461 |
+
lines = base.mark_line().encode(
|
462 |
+
size=alt.condition(~highlight, alt.value(1), alt.value(3))
|
463 |
+
)
|
464 |
+
|
465 |
+
return (gr.Column.update(visible=True), gr.Plot.update(points + lines, visible=True))
|
466 |
+
|
467 |
+
|
468 |
def format_time(seconds):
|
469 |
hours, remainder = divmod(seconds, 3600)
|
470 |
minutes, seconds = divmod(remainder, 60)
|
requirements.txt
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
accelerate
|
|
|
2 |
appdirs
|
3 |
bitsandbytes
|
4 |
black
|
@@ -7,10 +8,11 @@ datasets
|
|
7 |
fire
|
8 |
git+https://github.com/huggingface/peft.git
|
9 |
git+https://github.com/huggingface/transformers.git
|
|
|
10 |
huggingface_hub
|
|
|
11 |
numba
|
12 |
nvidia-ml-py3
|
13 |
-
|
14 |
-
loralib
|
15 |
-
sentencepiece
|
16 |
random-word
|
|
|
|
1 |
accelerate
|
2 |
+
altair
|
3 |
appdirs
|
4 |
bitsandbytes
|
5 |
black
|
|
|
8 |
fire
|
9 |
git+https://github.com/huggingface/peft.git
|
10 |
git+https://github.com/huggingface/transformers.git
|
11 |
+
gradio
|
12 |
huggingface_hub
|
13 |
+
loralib
|
14 |
numba
|
15 |
nvidia-ml-py3
|
16 |
+
pandas
|
|
|
|
|
17 |
random-word
|
18 |
+
sentencepiece
|