zetavg committed on
Commit 0054cc5
1 Parent(s): db1ee85

support switching to custom tokenizer

llama_lora/globals.py CHANGED
@@ -18,6 +18,7 @@ class Global:
 
     default_base_model_name: str = ""
     base_model_name: str = ""
+    tokenizer_name = None
     base_model_choices: List[str] = []
 
     trust_remote_code = False
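
Note: `tokenizer_name` intentionally defaults to `None` rather than to the base model name. Every consumer resolves it lazily with an `or` fallback, so the tokenizer follows the base model until the user explicitly overrides it. A minimal sketch of the pattern (the model names here are placeholders, not values from the app):

    class Global:
        base_model_name: str = "some-org/some-base-model"  # placeholder
        tokenizer_name = None  # None means "follow the base model"

    def resolve_tokenizer_name():
        # `or` treats both None and "" as unset, which is why the call
        # sites below read `Global.tokenizer_name or Global.base_model_name`.
        return Global.tokenizer_name or Global.base_model_name

    assert resolve_tokenizer_name() == "some-org/some-base-model"
    Global.tokenizer_name = "some-org/some-custom-tokenizer"  # placeholder
    assert resolve_tokenizer_name() == "some-org/some-custom-tokenizer"
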
llama_lora/lib/csv_logger.py CHANGED
@@ -25,8 +25,8 @@ class CSVLogger(FlaggingCallback):
 
     def setup(
         self,
-        components: List[Any],
-        flagging_dir: str | Path,
+        components,
+        flagging_dir,
     ):
         self.components = components
         self.flagging_dir = flagging_dir
@@ -36,7 +36,7 @@ class CSVLogger(FlaggingCallback):
         self,
         flag_data: List[Any],
         flag_option: str = "",
-        username: str | None = None,
+        username=None,
         filename="log.csv",
     ) -> int:
         flagging_dir = self.flagging_dir
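
Note: the removed annotations are the only substantive change here. `str | Path` and `str | None` use PEP 604 union syntax, which is evaluated when the `def` statement executes and raises a `TypeError` on Python 3.9 and earlier (unless `from __future__ import annotations` is in effect). Dropping the annotations keeps the module importable on older interpreters. A sketch of the failure and the portable alternatives (assuming pre-3.10 compatibility is the motivation, which the diff itself doesn't state):

    from pathlib import Path
    from typing import Optional, Union

    # On Python <= 3.9, defining this raises at import time:
    #   TypeError: unsupported operand type(s) for |: 'type' and 'type'
    # def setup(flagging_dir: str | Path, username: str | None = None): ...

    # Portable spellings that work on 3.7+:
    def setup(flagging_dir: Union[str, Path], username: Optional[str] = None):
        pass

    # ...or simply drop the annotations, as the diff above does.
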
llama_lora/ui/finetune_ui.py CHANGED
@@ -306,6 +306,7 @@ def do_train(
 ):
     try:
         base_model_name = Global.base_model_name
+        tokenizer_name = Global.tokenizer_name or Global.base_model_name
 
         resume_from_checkpoint = None
         if continue_from_model == "-" or continue_from_model == "None":
@@ -445,7 +446,7 @@ Train data (first 10):
         Global.should_stop_training = False
 
         base_model = get_new_base_model(base_model_name)
-        tokenizer = get_tokenizer(base_model_name)
+        tokenizer = get_tokenizer(tokenizer_name)
 
         # Do not let other tqdm iterations interfere the progress reporting after training starts.
         # progress.track_tqdm = False # setting this dynamically is not working, determining if track_tqdm should be enabled based on GPU cores at start instead.
llama_lora/ui/inference_ui.py CHANGED
@@ -33,9 +33,10 @@ class LoggingItem:
 
 def prepare_inference(lora_model_name, progress=gr.Progress(track_tqdm=True)):
     base_model_name = Global.base_model_name
+    tokenizer_name = Global.tokenizer_name or Global.base_model_name
 
     try:
-        get_tokenizer(base_model_name)
+        get_tokenizer(tokenizer_name)
         get_model(base_model_name, lora_model_name)
         return ("", "", gr.Textbox.update(visible=False))
 
llama_lora/ui/main_page.py CHANGED
@@ -25,13 +25,29 @@ def main_page():
                 """,
                 elem_id="page_title",
             )
-            global_base_model_select = gr.Dropdown(
-                label="Base Model",
-                elem_id="global_base_model_select",
-                choices=Global.base_model_choices,
-                value=lambda: Global.base_model_name,
-                allow_custom_value=True,
-            )
+            with gr.Column(elem_id="global_base_model_select_group"):
+                global_base_model_select = gr.Dropdown(
+                    label="Base Model",
+                    elem_id="global_base_model_select",
+                    choices=Global.base_model_choices,
+                    value=lambda: Global.base_model_name,
+                    allow_custom_value=True,
+                )
+                use_custom_tokenizer_btn = gr.Button(
+                    "Use custom tokenizer",
+                    elem_id="use_custom_tokenizer_btn")
+                global_tokenizer_select = gr.Dropdown(
+                    label="Tokenizer",
+                    elem_id="global_tokenizer_select",
+                    # choices=[],
+                    value=lambda: Global.base_model_name,
+                    visible=False,
+                    allow_custom_value=True,
+                )
+                use_custom_tokenizer_btn.click(
+                    fn=lambda: gr.Dropdown.update(visible=True),
+                    inputs=None,
+                    outputs=[global_tokenizer_select])
             # global_base_model_select_loading_status = gr.Markdown("", elem_id="global_base_model_select_loading_status")
 
         with gr.Column(elem_id="main_page_tabs_container") as main_page_tabs_container:
@@ -41,13 +57,17 @@ def main_page():
                 finetune_ui()
             with gr.Tab("Tokenizer"):
                 tokenizer_ui()
-        please_select_a_base_model_message = gr.Markdown("Please select a base model.", visible=False)
-        current_base_model_hint = gr.Markdown(lambda: Global.base_model_name, elem_id="current_base_model_hint")
+        please_select_a_base_model_message = gr.Markdown(
+            "Please select a base model.", visible=False)
+        current_base_model_hint = gr.Markdown(
+            lambda: Global.base_model_name, elem_id="current_base_model_hint")
+        current_tokenizer_hint = gr.Markdown(
+            lambda: Global.tokenizer_name, elem_id="current_tokenizer_hint")
         foot_info = gr.Markdown(get_foot_info)
 
         global_base_model_select.change(
             fn=pre_handle_change_base_model,
-            inputs=[],
+            inputs=[global_base_model_select],
             outputs=[main_page_tabs_container]
         ).then(
             fn=handle_change_base_model,
@@ -56,11 +76,27 @@ def main_page():
                 main_page_tabs_container,
                 please_select_a_base_model_message,
                 current_base_model_hint,
+                current_tokenizer_hint,
                 # global_base_model_select_loading_status,
                 foot_info
             ]
         )
 
+        global_tokenizer_select.change(
+            fn=pre_handle_change_tokenizer,
+            inputs=[global_tokenizer_select],
+            outputs=[main_page_tabs_container]
+        ).then(
+            fn=handle_change_tokenizer,
+            inputs=[global_tokenizer_select],
+            outputs=[
+                global_tokenizer_select,
+                main_page_tabs_container,
+                current_tokenizer_hint,
+                foot_info
+            ]
+        )
+
         main_page_blocks.load(_js=f"""
         function () {{
             {popperjs_core_code()}
@@ -95,6 +131,15 @@ def main_page():
             const base_model_name = current_base_model_hint_elem.innerText;
             document.querySelector('#global_base_model_select input').value = base_model_name;
             document.querySelector('#global_base_model_select').classList.add('show');
+
+            const current_tokenizer_hint_elem = document.querySelector('#current_tokenizer_hint > p');
+            const tokenizer_name = current_tokenizer_hint_elem && current_tokenizer_hint_elem.innerText;
+
+            if (tokenizer_name && tokenizer_name !== base_model_name) {
+                document.querySelector('#global_tokenizer_select input').value = tokenizer_name;
+                const btn = document.getElementById('use_custom_tokenizer_btn');
+                if (btn) btn.click();
+            }
         }, 3200);
         """ + """
     }
@@ -209,13 +254,21 @@ def main_page_custom_css():
 #page_title {
     flex-grow: 3;
 }
-#global_base_model_select {
+#global_base_model_select_group,
+#global_base_model_select,
+#global_tokenizer_select {
     position: relative;
    align-self: center;
-    min-width: 250px;
+    min-width: 250px !important;
+}
+#global_base_model_select,
+#global_tokenizer_select {
+    position: relative;
     padding: 2px 2px;
     border: 0;
     box-shadow: none;
+}
+#global_base_model_select {
     opacity: 0;
     pointer-events: none;
 }
@@ -223,10 +276,12 @@ def main_page_custom_css():
     opacity: 1;
     pointer-events: auto;
 }
-#global_base_model_select label .wrap-inner {
+#global_base_model_select label .wrap-inner,
+#global_tokenizer_select label .wrap-inner {
     padding: 2px 8px;
 }
-#global_base_model_select label span {
+#global_base_model_select label span,
+#global_tokenizer_select label span {
     margin-bottom: 2px;
     font-size: 80%;
     position: absolute;
@@ -234,9 +289,28 @@ def main_page_custom_css():
     left: 8px;
     opacity: 0;
 }
-#global_base_model_select:hover label span {
+#global_base_model_select_group:hover label span,
+#global_base_model_select:hover label span,
+#global_tokenizer_select:hover label span {
     opacity: 1;
 }
+#use_custom_tokenizer_btn {
+    position: absolute;
+    top: -16px;
+    right: 10px;
+    border: 0 !important;
+    width: auto !important;
+    background: transparent !important;
+    box-shadow: none !important;
+    padding: 0 !important;
+    font-weight: 100 !important;
+    text-decoration: underline;
+    font-size: 12px !important;
+    opacity: 0;
+}
+#global_base_model_select_group:hover #use_custom_tokenizer_btn {
+    opacity: 0.3;
+}
 
 #global_base_model_select_loading_status {
     position: absolute;
@@ -260,7 +334,7 @@ def main_page_custom_css():
     background: var(--block-background-fill);
 }
 
-#current_base_model_hint {
+#current_base_model_hint, #current_tokenizer_hint {
     display: none;
 }
 
@@ -754,24 +828,61 @@ def main_page_custom_css():
     return css
 
 
-def pre_handle_change_base_model():
-    return gr.Column.update(visible=False)
+def pre_handle_change_base_model(selected_base_model_name):
+    if Global.base_model_name != selected_base_model_name:
+        return gr.Column.update(visible=False)
+    if Global.tokenizer_name and Global.tokenizer_name != selected_base_model_name:
+        return gr.Column.update(visible=False)
+    return gr.Column.update(visible=True)
 
 
 def handle_change_base_model(selected_base_model_name):
     Global.base_model_name = selected_base_model_name
+    Global.tokenizer_name = selected_base_model_name
 
+    is_base_model_selected = False
     if Global.base_model_name:
-        return gr.Column.update(visible=True), gr.Markdown.update(visible=False), Global.base_model_name, get_foot_info()
+        is_base_model_selected = True
+
+    return (
+        gr.Column.update(visible=is_base_model_selected),
+        gr.Markdown.update(visible=not is_base_model_selected),
+        Global.base_model_name,
+        Global.tokenizer_name,
+        get_foot_info())
+
+
+def pre_handle_change_tokenizer(selected_tokenizer_name):
+    if Global.tokenizer_name != selected_tokenizer_name:
+        return gr.Column.update(visible=False)
+    return gr.Column.update(visible=True)
 
-    return gr.Column.update(visible=False), gr.Markdown.update(visible=True), Global.base_model_name, get_foot_info()
+
+def handle_change_tokenizer(selected_tokenizer_name):
+    Global.tokenizer_name = selected_tokenizer_name
+
+    show_tokenizer_select = True
+    if not Global.tokenizer_name:
+        show_tokenizer_select = False
+    if Global.tokenizer_name == Global.base_model_name:
+        show_tokenizer_select = False
+
+    return (
+        gr.Dropdown.update(visible=show_tokenizer_select),
+        gr.Column.update(visible=True),
+        Global.tokenizer_name,
+        get_foot_info()
+    )
 
 
 def get_foot_info():
     info = []
     if Global.version:
         info.append(f"LLaMA-LoRA Tuner `{Global.version}`")
-    info.append(f"Base model: `{Global.base_model_name}`")
+    if Global.base_model_name:
+        info.append(f"Base model: `{Global.base_model_name}`")
+    if Global.tokenizer_name and Global.tokenizer_name != Global.base_model_name:
+        info.append(f"Tokenizer: `{Global.tokenizer_name}`")
    if Global.ui_show_sys_info:
         info.append(f"Data dir: `{Global.data_dir}`")
     return f"""\
llama_lora/ui/tokenizer_ui.py CHANGED
@@ -7,12 +7,14 @@ from ..models import get_tokenizer
 
 
 def handle_decode(encoded_tokens_json):
-    base_model_name = Global.base_model_name
+    # base_model_name = Global.base_model_name
+    tokenizer_name = Global.tokenizer_name or Global.base_model_name
+
     try:
         encoded_tokens = json.loads(encoded_tokens_json)
         if Global.ui_dev_mode:
             return f"Not actually decoding tokens in UI dev mode.", gr.Markdown.update("", visible=False)
-        tokenizer = get_tokenizer(base_model_name)
+        tokenizer = get_tokenizer(tokenizer_name)
         decoded_tokens = tokenizer.decode(encoded_tokens)
         return decoded_tokens, gr.Markdown.update("", visible=False)
     except Exception as e:
@@ -20,11 +22,13 @@ def handle_decode(encoded_tokens_json):
 
 
 def handle_encode(decoded_tokens):
-    base_model_name = Global.base_model_name
+    # base_model_name = Global.base_model_name
+    tokenizer_name = Global.tokenizer_name or Global.base_model_name
+
    try:
         if Global.ui_dev_mode:
             return f"[\"Not actually encoding tokens in UI dev mode.\"]", gr.Markdown.update("", visible=False)
-        tokenizer = get_tokenizer(base_model_name)
+        tokenizer = get_tokenizer(tokenizer_name)
         result = tokenizer(decoded_tokens)
         encoded_tokens_json = json.dumps(result['input_ids'], indent=2)
         return encoded_tokens_json, gr.Markdown.update("", visible=False)
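
Note: for reference, the encode/decode round trip behind these handlers boils down to the standard `transformers` tokenizer API. A small standalone example (the tokenizer repo id is only an example, and the exact token ids depend on the tokenizer loaded):

    import json
    from transformers import AutoTokenizer

    # Any tokenizer repo id works here; this one is just an example.
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    # handle_encode path: text -> token ids, pretty-printed as JSON.
    result = tokenizer("Hello world")
    encoded_tokens_json = json.dumps(result["input_ids"], indent=2)
    print(encoded_tokens_json)  # e.g. [1, 15043, 3186] (ids vary by tokenizer)

    # handle_decode path: token ids -> text.
    decoded = tokenizer.decode(result["input_ids"])
    print(decoded)  # e.g. "<s> Hello world" (special tokens included)
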