Glavin001 committed
Commit 5b67ea9
Parent: 2f586d1

Add training callback to send predictions to WandB table (#521)

* WIP Add training callback to send predictions to WandB table

* WIP improve wandb table reporting callback

* WIP improve wandb table reporting callback (cont)

* Add VSCode launching for debugging

* Add tiny llama example

* WIP attempt to improve post-eval prediction generation for table

* WIP attempt to improve post-eval prediction generation for table - part 2

* WIP batch generation

* WIP attempt to handle sample_packing using position_ids for wandb prediction table

* WIP add code for debugging

* Fix sample_packing support for wandb prediction table

* Clean up code for PR review

* Add eval_table_size, eval_table_max_new_tokens configs & clean up code

* Clean up PR, delete VSCode config, add tiny-llama example

* Add eval_table_size, eval_table_max_new_tokens documentation. Fix linting/formatting

README.md CHANGED
@@ -534,6 +534,9 @@ eval_steps: # leave empty to eval at each epoch
 save_total_limit: # checkpoints saved at a time
 max_steps:
 
+eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
+eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128
+
 # save model as safetensors (require safetensors package)
 save_safetensors:
 
examples/llama-2/lora.yml CHANGED
@@ -56,6 +56,8 @@ flash_attention: true
 
 warmup_steps: 10
 eval_steps: 20
+eval_table_size: 5
+eval_table_max_new_tokens: 128
 save_steps:
 debug:
 deepspeed:
examples/llama-2/qlora.yml CHANGED
@@ -58,6 +58,7 @@ flash_attention: true
 
 warmup_steps: 10
 eval_steps: 20
+eval_table_size: 5
 save_steps:
 debug:
 deepspeed:
examples/llama-2/tiny-llama.yml ADDED
@@ -0,0 +1,69 @@
+base_model: PY007/TinyLlama-1.1B-step-50K-105b
+base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
+
+model_type: LlamaForCausalLM
+tokenizer_type: LlamaTokenizer
+is_llama_derived_model: true
+
+load_in_8bit: true
+load_in_4bit: false
+strict: false
+
+datasets:
+  - path: mhenrichsen/alpaca_2k_test
+    type: alpaca
+dataset_prepared_path: last_run_prepared
+val_set_size: 0.01
+output_dir: ./lora-out
+
+sequence_len: 4096
+sample_packing: true
+
+adapter: lora
+lora_model_dir:
+lora_r: 32
+lora_alpha: 16
+lora_dropout: 0.05
+lora_target_linear: true
+lora_fan_in_fan_out:
+
+wandb_project:
+wandb_entity:
+wandb_watch:
+wandb_run_id:
+wandb_log_model:
+
+gradient_accumulation_steps: 4
+micro_batch_size: 2
+num_epochs: 3
+optimizer: adamw_bnb_8bit
+lr_scheduler: cosine
+learning_rate: 0.0002
+
+train_on_inputs: false
+group_by_length: false
+bf16: true
+fp16: false
+tf32: false
+
+gradient_checkpointing: true
+early_stopping_patience:
+resume_from_checkpoint:
+local_rank:
+logging_steps: 1
+xformers_attention:
+flash_attention: true
+
+warmup_steps: 10
+eval_steps: 20
+eval_table_size: 5
+save_steps:
+debug:
+deepspeed:
+weight_decay: 0.0
+fsdp:
+fsdp_config:
+special_tokens:
+  bos_token: "<s>"
+  eos_token: "</s>"
+  unk_token: "<unk>"
src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -193,7 +193,7 @@ def flashattn_forward(
     # only on first autoregressive step q,k,v have same seqlen
     is_causal = key_states.shape == query_states.shape
 
-    if cu_seqlens is not None and max_seqlen is not None:
+    if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
         # special handling using sample packing
         qkv = torch.stack(
             [query_states, key_states, value_states], dim=2
@@ -261,6 +261,8 @@ def flashattn_forward(
            if attention_mask is not None
            else None,
        )
+        if q_unpad.dtype != kv_unpad.dtype:
+            kv_unpad = kv_unpad.to(q_unpad.dtype)
        output_unpad = flash_attn_varlen_kvpacked_func(
            q_unpad,
            kv_unpad,
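
For context on the second hunk: the flash-attn varlen kernels require the query and packed key/value tensors to share a dtype, a mismatch the new evaluation-time prediction path can apparently trigger under mixed precision. A minimal standalone sketch of the same guard, with made-up shapes and dtypes purely for illustration:

import torch

# illustrative only: total_tokens=8, num_heads=32, head_dim=128 (assumed shapes)
q_unpad = torch.randn(8, 32, 128, dtype=torch.float16)
kv_unpad = torch.randn(8, 2, 32, 128, dtype=torch.bfloat16)

# same check as the diff: cast kv to q's dtype before the kvpacked varlen call
if q_unpad.dtype != kv_unpad.dtype:
    kv_unpad = kv_unpad.to(q_unpad.dtype)

assert q_unpad.dtype == kv_unpad.dtype  # both torch.float16 now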
src/axolotl/utils/callbacks.py CHANGED
@@ -11,10 +11,13 @@ import numpy as np
 import pandas as pd
 import torch
 import torch.distributed as dist
+import wandb
 from datasets import load_dataset
 from optimum.bettertransformer import BetterTransformer
 from tqdm import tqdm
 from transformers import (
+    GenerationConfig,
+    Trainer,
     TrainerCallback,
     TrainerControl,
     TrainerState,
@@ -323,3 +326,191 @@ def bench_eval_callback_factory(trainer, tokenizer):
             metrics[key] = val
 
     return BenchEvalCallback
+
+
+def log_prediction_callback_factory(trainer: Trainer, tokenizer):
+    class LogPredictionCallback(TrainerCallback):
+        """Callback to log prediction values during each evaluation"""
+
+        def __init__(self, cfg):
+            self.cfg = cfg
+            self.logged = False
+
+        def on_evaluate(
+            self,
+            args: AxolotlTrainingArguments,  # pylint: disable=unused-argument
+            state: TrainerState,
+            control: TrainerControl,
+            train_dataloader,  # pylint: disable=unused-argument
+            eval_dataloader,
+            **kwargs,  # pylint: disable=unused-argument
+        ):
+            eval_table_size = self.cfg.eval_table_size
+
+            if eval_table_size <= 0:
+                return control
+
+            trainer.model.eval()
+            device = torch.device(self.cfg.device)
+
+            # pylint: disable=duplicate-code
+            generation_config = GenerationConfig(
+                max_new_tokens=self.cfg.eval_table_max_new_tokens,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                pad_token_id=tokenizer.pad_token_id,
+                do_sample=False,
+                use_cache=True,
+                return_dict_in_generate=True,
+                output_attentions=False,
+                output_hidden_states=False,
+                output_scores=False,
+            )
+
+            def logits_to_tokens(logits) -> str:
+                probabilities = torch.softmax(logits, dim=-1)
+                # Get the predicted token ids (the ones with the highest probability)
+                predicted_token_ids = torch.argmax(probabilities, dim=-1)
+                return predicted_token_ids
+
+            def find_ranges(lst):
+                ranges = []
+                start = 0
+                for i in range(1, len(lst)):
+                    if lst[i] == 0:
+                        ranges.append((start, i - 1))
+                        start = i
+                end = len(lst) - 1
+                ranges.append((start, end))
+                return ranges
+
+            def log_table_from_dataloader(name: str, table_dataloader):
+                table = wandb.Table(
+                    columns=[
+                        "id",
+                        "Prompt",
+                        "Correct Completion",
+                        "Predicted Completion (model.generate)",
+                        "Predicted Completion (trainer.prediction_step)",
+                    ]
+                )
+                row_index = 0
+
+                for batch in tqdm(table_dataloader):
+                    if row_index > eval_table_size:
+                        break
+
+                    batch_labels = batch["labels"].to(device)
+                    batch_input_ids = batch["input_ids"].to(device)
+
+                    if "position_ids" in batch:
+                        batch_pos_ids = batch["position_ids"].tolist()
+                    else:
+                        batch_pos_ids = [None] * len(batch["input_ids"])
+
+                    (_, batch_logits, _) = trainer.prediction_step(
+                        trainer.model,
+                        batch,
+                        prediction_loss_only=False,
+                    )
+
+                    prompt_token_ids_list = []
+                    pred_step_token_ids_list = []
+                    completion_token_ids_list = []
+
+                    for input_ids_all, labels_all, pos_ids, logits in zip(
+                        batch_input_ids,
+                        batch_labels,
+                        batch_pos_ids,
+                        batch_logits,
+                    ):
+                        if pos_ids is None:
+                            pos_ranges = [(0, len(input_ids_all) - 1)]
+                        else:
+                            pos_ranges = find_ranges(pos_ids)
+
+                        for pos_range in pos_ranges:
+                            start, end = pos_range
+                            if start == end:
+                                continue
+
+                            input_ids = input_ids_all[start : end + 1]
+                            labels = labels_all[start : end + 1]
+
+                            tokens_without_loss = labels == IGNORE_INDEX
+                            tokens_with_loss = labels != IGNORE_INDEX
+                            tokens_exclude_padding = input_ids != tokenizer.pad_token_id
+                            prompt_token_includes = (
+                                tokens_without_loss & tokens_exclude_padding
+                            )
+
+                            prompt_token_ids = input_ids[prompt_token_includes]
+                            prompt_token_ids_list.append(prompt_token_ids)
+
+                            completion_token_ids = input_ids[tokens_with_loss]
+                            completion_token_ids_list.append(completion_token_ids)
+
+                            pred_step_token_ids = logits_to_tokens(
+                                logits[start : end + 1]
+                            )[tokens_with_loss]
+                            pred_step_token_ids_list.append(pred_step_token_ids)
+
+                    prompt_texts = tokenizer.batch_decode(
+                        prompt_token_ids_list, skip_special_tokens=True
+                    )
+                    completion_texts = tokenizer.batch_decode(
+                        completion_token_ids_list, skip_special_tokens=True
+                    )
+                    pred_step_texts = tokenizer.batch_decode(
+                        pred_step_token_ids_list, skip_special_tokens=True
+                    )
+
+                    with torch.no_grad():
+                        prompt_encoding = tokenizer(
+                            prompt_texts, padding=True, return_tensors="pt"
+                        ).to(self.cfg.device)
+                        predictions = trainer.model.generate(
+                            **prompt_encoding, generation_config=generation_config
+                        )
+
+                    prediction_all_tokens = predictions["sequences"].cpu().tolist()
+                    prediction_without_prompt_tokens_list = []
+                    for prompt_token_ids, prediction_tokens in zip(
+                        prompt_token_ids_list, prediction_all_tokens
+                    ):
+                        prediction_without_prompt_tokens = prediction_tokens[
+                            len(prompt_token_ids) :
+                        ]
+                        prediction_without_prompt_tokens_list.append(
+                            prediction_without_prompt_tokens
+                        )
+
+                    predicted_texts = tokenizer.batch_decode(
+                        prediction_without_prompt_tokens_list, skip_special_tokens=True
+                    )
+
+                    for (
+                        prompt_text,
+                        completion_text,
+                        prediction_text,
+                        pred_step_text,
+                    ) in zip(
+                        prompt_texts, completion_texts, predicted_texts, pred_step_texts
+                    ):
+                        table.add_data(
+                            row_index,
+                            prompt_text,
+                            completion_text,
+                            prediction_text,
+                            pred_step_text,
+                        )
+                        row_index += 1
+
+                wandb.run.log({f"{name} - Predictions vs Ground Truth": table})
+
+            if is_main_process():
+                log_table_from_dataloader("Eval", eval_dataloader)
+
+            return control
+
+    return LogPredictionCallback
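
To make the sample-packing handling concrete: with sample_packing enabled, each batch row holds several examples back to back, and their position_ids restart at 0 at every boundary. The callback's find_ranges helper recovers the (start, end) span of each packed example so prompts and completions can be sliced back out. A standalone run on a made-up position_ids list:

def find_ranges(lst):
    # same logic as the helper in the callback above
    ranges = []
    start = 0
    for i in range(1, len(lst)):
        if lst[i] == 0:
            ranges.append((start, i - 1))
            start = i
    ranges.append((start, len(lst) - 1))
    return ranges

position_ids = [0, 1, 2, 3, 0, 1, 2, 0, 1]  # three examples packed into one row
print(find_ranges(position_ids))            # [(0, 3), (4, 6), (7, 8)]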
src/axolotl/utils/config.py CHANGED
@@ -48,6 +48,8 @@ def normalize_config(cfg):
     )
     cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
     cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
+    cfg.eval_table_size = cfg.eval_table_size or 0
+    cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128
     choose_device(cfg)
    cfg.ddp = cfg.ddp if cfg.ddp is not None else cfg.world_size != 1
    if cfg.ddp:
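
In practice both keys can be left out of a config entirely: eval_table_size falls back to 0 (prediction table disabled) and eval_table_max_new_tokens to 128. A tiny sketch of the defaulting, using a plain namespace in place of axolotl's real cfg object (an assumption for illustration):

from types import SimpleNamespace

cfg = SimpleNamespace(eval_table_size=None, eval_table_max_new_tokens=None)

# same "or"-based defaulting as normalize_config above
cfg.eval_table_size = cfg.eval_table_size or 0                 # 0 -> wandb table stays off
cfg.eval_table_max_new_tokens = cfg.eval_table_max_new_tokens or 128

print(cfg.eval_table_size, cfg.eval_table_max_new_tokens)      # 0 128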
src/axolotl/utils/models.py CHANGED
@@ -296,10 +296,10 @@ def load_model(
     if (
         hasattr(model.config, "max_position_embeddings")
         and model.config.max_position_embeddings
-        and cfg.sequence_len >= model.config.max_position_embeddings
+        and cfg.sequence_len > model.config.max_position_embeddings
     ):
         LOG.warning(
-            f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
+            f"increasing model.config.max_position_embeddings from {model.config.max_position_embeddings} to {cfg.sequence_len}"
         )
         model.config.max_position_embeddings = cfg.sequence_len
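
The comparison change only matters at the boundary: with ">=", a model whose max_position_embeddings already equals cfg.sequence_len was rewritten to the same value and logged a misleading "increasing" warning, whereas ">" leaves it untouched. A minimal illustration with assumed values:

sequence_len = 4096
max_position_embeddings = 4096  # config already matches the training sequence length

needs_extend_old = sequence_len >= max_position_embeddings  # True  -> spurious warning
needs_extend_new = sequence_len > max_position_embeddings   # False -> left as-is
print(needs_extend_old, needs_extend_new)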
 
src/axolotl/utils/trainer.py CHANGED
@@ -30,6 +30,7 @@ from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
     bench_eval_callback_factory,
+    log_prediction_callback_factory,
 )
 from axolotl.utils.collators import DataCollatorForSeq2Seq
 from axolotl.utils.dataloader import MultipackDistributedDataloader
@@ -703,6 +704,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
         **trainer_kwargs,
     )
 
+    if cfg.use_wandb and cfg.eval_table_size > 0:
+        LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
+        trainer.add_callback(LogPredictionCallback(cfg))
+
     if cfg.do_bench_eval:
         trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
 
713