winglian commited on
Commit
5159d00
·
1 Parent(s): c0f50d9

fix sharegpt tokenization, refactor tokenization debugging

Browse files
scripts/finetune.py CHANGED
@@ -11,6 +11,8 @@ import yaml
11
  from attrdict import AttrDefault
12
 
13
  # add src to the pythonpath so we don't need to pip install this
 
 
14
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
15
  src_dir = os.path.join(project_root, "src")
16
  sys.path.insert(0, src_dir)
@@ -42,36 +44,6 @@ def choose_device(cfg):
42
  cfg.device_map = {"": cfg.device}
43
 
44
 
45
- def check_dataset_labels(dataset, tokenizer):
46
- from termcolor import colored
47
-
48
- # the dataset is already shuffled, so let's just check the first 5 elements
49
- for idx in range(5):
50
- # Get the input_ids, labels, and attention_mask from the dataset
51
- input_ids = dataset[idx]["input_ids"]
52
- labels = dataset[idx]["labels"]
53
- attention_mask = dataset[idx]["attention_mask"]
54
-
55
- # You can compare the input_ids and labels element-wise
56
- # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
57
- colored_tokens = []
58
- for i, (input_id, label_id, mask) in enumerate(
59
- zip(input_ids, labels, attention_mask)
60
- ):
61
- decoded_input_token = tokenizer.decode(input_id)
62
- # Choose the color based on whether the label has the ignore value or not
63
- color = (
64
- "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
65
- )
66
- colored_token = colored(decoded_input_token, color) + colored(
67
- f"({label_id}, {mask})", "white"
68
- )
69
- colored_tokens.append(colored_token)
70
-
71
- logging.info(" ".join(colored_tokens))
72
- logging.info("\n\n\n")
73
-
74
-
75
  def do_inference(cfg, model, tokenizer):
76
  tokenizer.add_special_tokens({"unk_token": "<unk>"})
77
  tokenizer.add_special_tokens({"bos_token": "<s>"})
@@ -199,8 +171,9 @@ def train(
199
  return
200
 
201
  if cfg.debug:
 
202
  check_dataset_labels(
203
- train_dataset.select([random.randrange(0, len(train_dataset) - 1)]),
204
  tokenizer,
205
  )
206
 
 
11
  from attrdict import AttrDefault
12
 
13
  # add src to the pythonpath so we don't need to pip install this
14
+ from axolotl.utils.tokenization import check_dataset_labels
15
+
16
  project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
17
  src_dir = os.path.join(project_root, "src")
18
  sys.path.insert(0, src_dir)
 
44
  cfg.device_map = {"": cfg.device}
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def do_inference(cfg, model, tokenizer):
48
  tokenizer.add_special_tokens({"unk_token": "<unk>"})
49
  tokenizer.add_special_tokens({"bos_token": "<s>"})
 
171
  return
172
 
173
  if cfg.debug:
174
+ logging.info("check_dataset_labels...")
175
  check_dataset_labels(
176
+ train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]),
177
  tokenizer,
178
  )
179
 
src/axolotl/prompters.py CHANGED
@@ -127,7 +127,7 @@ conv_vicuna_v1_1 = Conversation(
127
 
128
 
129
  class ShareGPTPrompter:
130
- def build_prompt(self, source, tokenizer):
131
  # ignore the system prompt if provided
132
  if source[0]["from"] == "system":
133
  source.pop(0)
@@ -157,13 +157,14 @@ class ShareGPTPrompter:
157
  role = roles[sentence["from"]]
158
  assert role == conv.roles[j % 2]
159
  conv.append_message(role, sentence["value"])
 
160
  conversation = conv.get_prompt()
161
 
162
  # Tokenize conversations
163
  tokenized_result = tokenizer(
164
  conversation,
165
  truncation=True,
166
- max_length=2048, # FIXME
167
  padding=False,
168
  return_tensors=None,
169
  )
@@ -173,7 +174,9 @@ class ShareGPTPrompter:
173
  sep = conv.sep + conv.roles[1] + ": "
174
 
175
  rounds = conversation.split(conv.sep2)
 
176
  cur_len = 1
 
177
  for i, rou in enumerate(rounds):
178
  if rou == "":
179
  break
@@ -182,19 +185,27 @@ class ShareGPTPrompter:
182
  if len(parts) != 2:
183
  break
184
  parts[0] += sep
185
- round_len = len(tokenizer(rou)["input_ids"])
186
- instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
 
187
  target[cur_len : cur_len + instruction_len] = [
188
  IGNORE_TOKEN_ID
189
  ] * instruction_len
190
 
191
  cur_len += round_len
192
- target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
 
 
 
 
 
 
193
  attention_mask = [
194
  1 if x != tokenizer.pad_token_id else 0
195
  for x in tokenized_result["input_ids"]
196
  ]
197
 
 
198
  return dict(
199
  input_ids=tokenized_result["input_ids"],
200
  labels=target,
 
127
 
128
 
129
  class ShareGPTPrompter:
130
+ def build_prompt(self, source, tokenizer, sequence_len=2048):
131
  # ignore the system prompt if provided
132
  if source[0]["from"] == "system":
133
  source.pop(0)
 
157
  role = roles[sentence["from"]]
158
  assert role == conv.roles[j % 2]
159
  conv.append_message(role, sentence["value"])
160
+ # TODO, this concatenates everything, but doesn't seem to properly add the eos_token_id, as the eos_token gets split up
161
  conversation = conv.get_prompt()
162
 
163
  # Tokenize conversations
164
  tokenized_result = tokenizer(
165
  conversation,
166
  truncation=True,
167
+ max_length=sequence_len, # FIXME
168
  padding=False,
169
  return_tensors=None,
170
  )
 
174
  sep = conv.sep + conv.roles[1] + ": "
175
 
176
  rounds = conversation.split(conv.sep2)
177
+ rounds = [r + conv.sep2 for r in rounds]
178
  cur_len = 1
179
+ target[0] = IGNORE_TOKEN_ID # mask out the bos
180
  for i, rou in enumerate(rounds):
181
  if rou == "":
182
  break
 
185
  if len(parts) != 2:
186
  break
187
  parts[0] += sep
188
+ round_len = len(tokenizer(rou)["input_ids"]) - 1 # -1 ignores the bos_token generated for this
189
+ # we have to strip the initial part, any dangling whitespace creates an additional ghost token
190
+ instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1 # -1 ignores the bos_token generated for this
191
  target[cur_len : cur_len + instruction_len] = [
192
  IGNORE_TOKEN_ID
193
  ] * instruction_len
194
 
195
  cur_len += round_len
196
+ if cur_len >= sequence_len:
197
+ break
198
+
199
+ # Fix: Truncate the target to have the same length as input_ids
200
+ target = target[:len(tokenized_result["input_ids"])]
201
+ # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
202
+
203
  attention_mask = [
204
  1 if x != tokenizer.pad_token_id else 0
205
  for x in tokenized_result["input_ids"]
206
  ]
207
 
208
+ # TODO truncate len to sequence_len
209
  return dict(
210
  input_ids=tokenized_result["input_ids"],
211
  labels=target,
src/axolotl/utils/models.py CHANGED
@@ -53,7 +53,7 @@ def load_model(
53
  logging.info("patching with xformers attention")
54
  hijack_llama_attention()
55
 
56
- torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,)
57
  try:
58
  if cfg.load_4bit:
59
  from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
@@ -161,11 +161,11 @@ def load_model(
161
  tokenizer.add_special_tokens({"pad_token": "[PAD]"})
162
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
163
 
164
- if cfg.special_tokens:
165
- for k, v in cfg.special_tokens.items():
166
- setattr(tokenizer, k, v)
167
 
168
- if load_in_8bit and not cfg.load_4bit:
169
  logging.info("converting model w/ prepare_model_for_int8_training")
170
  model = prepare_model_for_int8_training(model)
171
 
 
53
  logging.info("patching with xformers attention")
54
  hijack_llama_attention()
55
 
56
+ torch_dtype = torch.float16 if cfg.load_in_8bit or cfg.fp16 or cfg.bf16 else torch.float32
57
  try:
58
  if cfg.load_4bit:
59
  from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import (
 
161
  tokenizer.add_special_tokens({"pad_token": "[PAD]"})
162
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
163
 
164
+ if cfg.tokens:
165
+ for k, v in cfg.tokens.items():
166
+ tokenizer.add_special_tokens({k: v})
167
 
168
+ if load_in_8bit and cfg.load_4bit:
169
  logging.info("converting model w/ prepare_model_for_int8_training")
170
  model = prepare_model_for_int8_training(model)
171
 
src/axolotl/utils/tokenization.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from termcolor import colored
2
+ import logging
3
+
4
+ def check_dataset_labels(dataset, tokenizer):
5
+ # the dataset is already shuffled, so let's just check the first 5 elements
6
+ for idx in range(5):
7
+ check_example_labels(dataset[idx], tokenizer)
8
+
9
+
10
+ def check_example_labels(example, tokenizer):
11
+ # Get the input_ids, labels, and attention_mask from the dataset
12
+ input_ids = example["input_ids"]
13
+ labels = example["labels"]
14
+ attention_mask =example["attention_mask"]
15
+
16
+ # You can compare the input_ids and labels element-wise
17
+ # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
18
+ colored_tokens = []
19
+ for i, (input_id, label_id, mask) in enumerate(
20
+ zip(input_ids, labels, attention_mask)
21
+ ):
22
+ decoded_input_token = tokenizer.decode(input_id)
23
+ # Choose the color based on whether the label has the ignore value or not
24
+ color = (
25
+ "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
26
+ )
27
+ colored_token = colored(decoded_input_token, color) + colored(
28
+ f"({label_id}, {mask}, {input_id})", "white"
29
+ )
30
+ colored_tokens.append(colored_token)
31
+
32
+ logging.info(" ".join(colored_tokens))
33
+ logging.info("\n\n\n")
src/axolotl/utils/trainer.py CHANGED
@@ -61,6 +61,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
61
  group_by_length=cfg.group_by_length,
62
  report_to="wandb" if cfg.use_wandb else None,
63
  run_name=cfg.wandb_run_id if cfg.use_wandb else None,
 
 
 
 
 
64
  **training_arguments_kwargs,
65
  )
66
 
 
61
  group_by_length=cfg.group_by_length,
62
  report_to="wandb" if cfg.use_wandb else None,
63
  run_name=cfg.wandb_run_id if cfg.use_wandb else None,
64
+ optim=cfg.optimizer if cfg.optimizer != "adam8bit" else cfg.optimizer,
65
+ lr_scheduler_type=cfg.lr_scheduler if cfg.lr_scheduler else None,
66
+ weight_decay=cfg.weight_decay if cfg.weight_decay else 0.0,
67
+ fsdp=cfg.fsdp.split(" ") if cfg.fsdp else None,
68
+ fsdp_transformer_layer_cls_to_wrap=cfg.fsdp_transformer_layer_cls_to_wrap if cfg.fsdp_transformer_layer_cls_to_wrap else None,
69
  **training_arguments_kwargs,
70
  )
71