roborovski winglian commited on
Commit
cf64284
1 Parent(s): c996881

Phi-3 conversation format, example training script and perplexity metric (#1582)

Browse files

* phi-3 support and perplexity metric

* phi-3 chat template

* metrics updates

* chore: lint

* fix assertion on Tensor

* fix tests since tokenization happens in the metric

* fix perplexity value of shorter passage

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

.gitignore CHANGED
@@ -176,3 +176,9 @@ qlora-out/*
176
  mlruns/*
177
 
178
  /.quarto/
 
 
 
 
 
 
 
176
  mlruns/*
177
 
178
  /.quarto/
179
+ prepared-datasets/
180
+ submit.sh
181
+ *.out*
182
+
183
+ typings/
184
+ out/
examples/phi/phi3-ft.yml ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: microsoft/Phi-3-mini-4k-instruct
2
+ trust_remote_code: true
3
+ model_type: AutoModelForCausalLM
4
+ tokenizer_type: AutoTokenizer
5
+ chat_template: phi_3
6
+
7
+ load_in_8bit: false
8
+ load_in_4bit: false
9
+ strict: false
10
+
11
+ datasets:
12
+ - path: garage-bAInd/Open-Platypus
13
+ type: alpaca:phi
14
+
15
+ dataset_prepared_path:
16
+ val_set_size: 0.01
17
+ output_dir: ./out
18
+
19
+ sequence_len: 4096
20
+ sample_packing: true
21
+ pad_to_sequence_len: true
22
+
23
+ adapter: lora
24
+ lora_model_dir:
25
+ lora_r: 64
26
+ lora_alpha: 32
27
+ lora_dropout: 0.05
28
+ lora_target_linear: true
29
+ lora_fan_in_fan_out:
30
+
31
+ gradient_accumulation_steps: 1
32
+ micro_batch_size: 2
33
+ num_epochs: 1
34
+ optimizer: adamw_torch
35
+ adam_beta2: 0.95
36
+ adam_epsilon: 0.00001
37
+ max_grad_norm: 1.0
38
+ lr_scheduler: cosine
39
+ learning_rate: 5.0e-6
40
+
41
+ train_on_inputs: false
42
+ group_by_length: false
43
+ bf16: auto
44
+
45
+ gradient_checkpointing: true
46
+ gradient_checkpointing_kwargs:
47
+ use_reentrant: True
48
+ early_stopping_patience: 3
49
+ logging_steps: 1
50
+ flash_attention: true
51
+
52
+ eval_steps: 1000
53
+ save_steps: 5000
54
+ eval_table_size: 2
55
+ eval_batch_size: 2
56
+ eval_sample_packing: false
57
+ eval_max_new_tokens: 32
58
+ eval_causal_lm_metrics: ["perplexity"]
59
+ do_causal_lm_eval: true
60
+
61
+ warmup_ratio: 0.2
62
+ debug: true
63
+ weight_decay: 0.1
64
+ resize_token_embeddings_to_32x: true
src/axolotl/prompters.py CHANGED
@@ -20,6 +20,7 @@ class PromptStyle(Enum):
20
  INSTRUCT = "instruct"
21
  CHAT = "chat"
22
  CHATML = "chatml"
 
23
 
24
 
25
  class Prompter:
@@ -38,9 +39,9 @@ class AlpacaPrompter(Prompter):
38
  system_format: str = "{system}"
39
  turn_format: str
40
  turn_no_input_format: str
41
- prompt_style: Optional[PromptStyle] = None
42
 
43
- def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
44
  self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
45
  self.match_prompt_style()
46
 
@@ -52,16 +53,20 @@ class AlpacaPrompter(Prompter):
52
  "### Instruction:\n{instruction}\n\n### Response:\n"
53
  )
54
  self.system_format = "{system}\n\n"
55
- if self.prompt_style == PromptStyle.CHAT.value:
56
  self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
57
  self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
58
  self.system_format = "SYSTEM: {system}\n"
59
- if self.prompt_style == PromptStyle.CHATML.value:
60
  self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
61
  self.turn_no_input_format = (
62
  "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
63
  )
64
  self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
 
 
 
 
65
 
66
  def _build_result(self, instruction, input_text, output):
67
  # returns the full prompt from instruction and optional input
@@ -381,12 +386,14 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
381
  conversation: Optional[Union[str, Conversation]] = None,
382
  role_key_human: Optional[str] = None,
383
  role_key_model: Optional[str] = None,
 
384
  roles: Optional[dict] = None,
385
  ):
386
  super().__init__(
387
  conversation=conversation,
388
  role_key_human=role_key_human,
389
  role_key_model=role_key_model,
 
390
  roles=roles,
391
  )
392
 
 
20
  INSTRUCT = "instruct"
21
  CHAT = "chat"
22
  CHATML = "chatml"
23
+ PHI = "phi"
24
 
25
 
26
  class Prompter:
 
39
  system_format: str = "{system}"
40
  turn_format: str
41
  turn_no_input_format: str
42
+ prompt_style: Optional[str] = None
43
 
44
+ def __init__(self, prompt_style: Optional[str] = PromptStyle.INSTRUCT.value):
45
  self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
46
  self.match_prompt_style()
47
 
 
53
  "### Instruction:\n{instruction}\n\n### Response:\n"
54
  )
55
  self.system_format = "{system}\n\n"
56
+ elif self.prompt_style == PromptStyle.CHAT.value:
57
  self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
58
  self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
59
  self.system_format = "SYSTEM: {system}\n"
60
+ elif self.prompt_style == PromptStyle.CHATML.value:
61
  self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
62
  self.turn_no_input_format = (
63
  "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
64
  )
65
  self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
66
+ elif self.prompt_style == PromptStyle.PHI.value:
67
+ self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
68
+ self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
69
+ self.system_format = "<|system|>{system}\n"
70
 
71
  def _build_result(self, instruction, input_text, output):
72
  # returns the full prompt from instruction and optional input
 
386
  conversation: Optional[Union[str, Conversation]] = None,
387
  role_key_human: Optional[str] = None,
388
  role_key_model: Optional[str] = None,
389
+ role_key_tool: Optional[str] = None,
390
  roles: Optional[dict] = None,
391
  ):
392
  super().__init__(
393
  conversation=conversation,
394
  role_key_human=role_key_human,
395
  role_key_model=role_key_model,
396
+ role_key_tool=role_key_tool,
397
  roles=roles,
398
  )
399
 
src/axolotl/utils/callbacks/__init__.py CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
5
  import logging
6
  import math
7
  import os
 
8
  from shutil import copyfile
9
  from tempfile import NamedTemporaryFile
10
  from typing import TYPE_CHECKING, Any, Dict, List
@@ -30,6 +31,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
30
 
31
  from axolotl.utils import is_mlflow_available
32
  from axolotl.utils.bench import log_gpu_memory_usage
 
33
  from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
34
  from axolotl.utils.distributed import (
35
  barrier,
@@ -374,10 +376,14 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
374
  def __maybe_load_metrics(self):
375
  metrics = {}
376
  for metric in self.cfg.eval_causal_lm_metrics:
377
- try:
378
- metrics[metric] = evaluate.load(metric)
379
- except Exception as exc: # pylint: disable=broad-exception-caught
380
- LOG.warning(f"{metric}: {exc.args}")
 
 
 
 
381
  return metrics
382
 
383
  def on_evaluate(
@@ -421,13 +427,20 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
421
  # safely compute a metric and return the score if the format is correct
422
  metric_score = None
423
  try:
424
- metric_score = metric.compute(**kwargs)
 
 
 
 
 
 
425
  return (
426
  metric_score["score"]
427
  if "score" in metric_score
428
  else metric_score["mean_score"]
429
  )
430
  except Exception: # pylint: disable=broad-exception-caught
 
431
  LOG.debug(
432
  f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
433
  )
@@ -443,11 +456,12 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
443
  predictions=predictions,
444
  sources=sources,
445
  )
446
- score = score or compute(
447
- metric,
448
- references=[[r] for r in references],
449
- predictions=predictions,
450
- )
 
451
  scores[metric_name] = score
452
  return scores
453
 
 
5
  import logging
6
  import math
7
  import os
8
+ import traceback
9
  from shutil import copyfile
10
  from tempfile import NamedTemporaryFile
11
  from typing import TYPE_CHECKING, Any, Dict, List
 
31
 
32
  from axolotl.utils import is_mlflow_available
33
  from axolotl.utils.bench import log_gpu_memory_usage
34
+ from axolotl.utils.callbacks.perplexity import Perplexity
35
  from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
36
  from axolotl.utils.distributed import (
37
  barrier,
 
376
  def __maybe_load_metrics(self):
377
  metrics = {}
378
  for metric in self.cfg.eval_causal_lm_metrics:
379
+ if metric == "perplexity":
380
+ max_seq_len = self.cfg.eval_max_new_tokens
381
+ metrics[metric] = Perplexity(trainer.model, tokenizer, max_seq_len)
382
+ else:
383
+ try:
384
+ metrics[metric] = evaluate.load(metric)
385
+ except Exception as exc: # pylint: disable=broad-exception-caught
386
+ LOG.warning(f"{metric}: {exc.args}")
387
  return metrics
388
 
389
  def on_evaluate(
 
427
  # safely compute a metric and return the score if the format is correct
428
  metric_score = None
429
  try:
430
+ # Only pass the kwargs that are in the metric's feature list
431
+ metric_kwargs = {
432
+ k: kwargs[k]
433
+ for k in metric._feature_names() # pylint: disable=protected-access
434
+ if k in kwargs
435
+ }
436
+ metric_score = metric.compute(**metric_kwargs)
437
  return (
438
  metric_score["score"]
439
  if "score" in metric_score
440
  else metric_score["mean_score"]
441
  )
442
  except Exception: # pylint: disable=broad-exception-caught
443
+ traceback.print_exc()
444
  LOG.debug(
445
  f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
446
  )
 
456
  predictions=predictions,
457
  sources=sources,
458
  )
459
+ if score is None:
460
+ score = compute(
461
+ metric,
462
+ references=[[r] for r in references],
463
+ predictions=predictions,
464
+ )
465
  scores[metric_name] = score
466
  return scores
467
 
src/axolotl/utils/callbacks/perplexity.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """callback to calculate perplexity as an evaluation metric."""
2
+ from typing import Dict, List, Optional
3
+
4
+ import torch
5
+ from torch import Tensor
6
+ from tqdm import tqdm
7
+ from transformers.modeling_outputs import CausalLMOutput
8
+ from transformers.modeling_utils import PreTrainedModel
9
+ from transformers.tokenization_utils import PreTrainedTokenizer
10
+
11
+
12
+ class Perplexity:
13
+ """
14
+ Calculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.
15
+ This is a custom variant that doesn't re-tokenize the input or re-load the model.
16
+ """
17
+
18
+ def __init__(
19
+ self,
20
+ model: PreTrainedModel,
21
+ tokenizer: PreTrainedTokenizer,
22
+ max_seq_len: int,
23
+ stride: int = 512,
24
+ ) -> None:
25
+ self.max_seq_len = max_seq_len
26
+ self.stride = stride
27
+ self.model = model
28
+ self.tokenizer = tokenizer
29
+ self.device = model.device
30
+ self.name = "perplexity"
31
+
32
+ def _feature_names(self) -> List[str]:
33
+ return ["references"]
34
+
35
+ def compute(
36
+ self,
37
+ references: Optional[List[str]] = None,
38
+ ) -> Dict[str, float]:
39
+ """
40
+ Compute perplexity in a fixed length sliding window across the sequence.
41
+ """
42
+ assert references is not None, "Missing parameter: references"
43
+
44
+ references_tokenized = self.tokenizer(
45
+ references, return_tensors="pt", padding=True, truncation=True
46
+ )
47
+ input_ids: Tensor = references_tokenized["input_ids"] # type: ignore
48
+ input_ids = input_ids.to(self.device)
49
+
50
+ sequence_length = input_ids.size(1)
51
+
52
+ losses = []
53
+ prev_end_loc = 0
54
+ for begin_loc in tqdm(range(0, sequence_length, self.stride)):
55
+ end_loc = min(begin_loc + self.max_seq_len, sequence_length)
56
+ trg_len = end_loc - prev_end_loc
57
+ input_ids_slice = input_ids[:, begin_loc:end_loc]
58
+ labels_slice = input_ids_slice.clone()
59
+ labels_slice[:, :-trg_len] = -100
60
+
61
+ with torch.no_grad():
62
+ outputs: CausalLMOutput = self.model(
63
+ input_ids=input_ids_slice, labels=labels_slice
64
+ )
65
+
66
+ losses.append(outputs.loss)
67
+
68
+ prev_end_loc = end_loc
69
+ if end_loc == sequence_length:
70
+ break
71
+
72
+ perplexity = torch.exp(torch.stack(losses).mean()).item()
73
+
74
+ return {
75
+ "score": perplexity,
76
+ }
src/axolotl/utils/chat_templates.py CHANGED
@@ -25,6 +25,7 @@ def chat_templates(user_choice: str):
25
  "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
26
  "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
27
  "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
 
28
  }
29
 
30
  if user_choice in templates:
 
25
  "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
26
  "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
27
  "llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
28
+ "phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
29
  }
30
 
31
  if user_choice in templates:
src/axolotl/utils/config/__init__.py CHANGED
@@ -10,6 +10,7 @@ from transformers.utils import is_torch_bf16_gpu_available
10
 
11
  from axolotl.utils.bench import log_gpu_memory_usage
12
  from axolotl.utils.config.models.input.v0_4_1 import (
 
13
  AxolotlConfigWCapabilities,
14
  AxolotlInputConfig,
15
  )
@@ -586,13 +587,12 @@ def legacy_validate_config(cfg):
586
  )
587
 
588
  if cfg.eval_causal_lm_metrics:
589
- supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
590
  if not isinstance(cfg.eval_causal_lm_metrics, list):
591
  raise ValueError("eval_causal_lm_metrics must be a list")
592
  # only ["sacrebleu", "comet", "ter", "chrf"] supported
593
- if set(cfg.eval_causal_lm_metrics) - set(supported_metrics):
594
  raise ValueError(
595
- f"eval_causal_lm_metrics must be one of {supported_metrics}"
596
  )
597
 
598
  # TODO
 
10
 
11
  from axolotl.utils.bench import log_gpu_memory_usage
12
  from axolotl.utils.config.models.input.v0_4_1 import (
13
+ SUPPORTED_METRICS,
14
  AxolotlConfigWCapabilities,
15
  AxolotlInputConfig,
16
  )
 
587
  )
588
 
589
  if cfg.eval_causal_lm_metrics:
 
590
  if not isinstance(cfg.eval_causal_lm_metrics, list):
591
  raise ValueError("eval_causal_lm_metrics must be a list")
592
  # only ["sacrebleu", "comet", "ter", "chrf"] supported
593
+ if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
594
  raise ValueError(
595
+ f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
596
  )
597
 
598
  # TODO
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -17,6 +17,8 @@ from axolotl.utils.config.models.internals import GPUCapabilities
17
 
18
  LOG = logging.getLogger("axolotl.utils.config.models.input")
19
 
 
 
20
 
21
  class DeprecatedParameters(BaseModel):
22
  """configurations that are deprecated"""
@@ -176,6 +178,7 @@ class ChatTemplate(str, Enum):
176
  gemma = "gemma" # pylint: disable=invalid-name
177
  cohere = "cohere" # pylint: disable=invalid-name
178
  llama3 = "llama3" # pylint: disable=invalid-name
 
179
 
180
 
181
  class LoftQConfig(BaseModel):
@@ -1073,13 +1076,12 @@ class AxolotlInputConfig(
1073
  )
1074
 
1075
  if data.get("eval_causal_lm_metrics"):
1076
- supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
1077
  if not isinstance(data.get("eval_causal_lm_metrics"), list):
1078
  raise ValueError("eval_causal_lm_metrics must be a list")
1079
  # only ["sacrebleu", "comet", "ter", "chrf"] supported
1080
- if set(data.get("eval_causal_lm_metrics")) - set(supported_metrics):
1081
  raise ValueError(
1082
- f"eval_causal_lm_metrics must be one of {supported_metrics}"
1083
  )
1084
  return data
1085
 
 
17
 
18
  LOG = logging.getLogger("axolotl.utils.config.models.input")
19
 
20
+ SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
21
+
22
 
23
  class DeprecatedParameters(BaseModel):
24
  """configurations that are deprecated"""
 
178
  gemma = "gemma" # pylint: disable=invalid-name
179
  cohere = "cohere" # pylint: disable=invalid-name
180
  llama3 = "llama3" # pylint: disable=invalid-name
181
+ phi_3 = "phi_3" # pylint: disable=invalid-name
182
 
183
 
184
  class LoftQConfig(BaseModel):
 
1076
  )
1077
 
1078
  if data.get("eval_causal_lm_metrics"):
 
1079
  if not isinstance(data.get("eval_causal_lm_metrics"), list):
1080
  raise ValueError("eval_causal_lm_metrics must be a list")
1081
  # only ["sacrebleu", "comet", "ter", "chrf"] supported
1082
+ if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
1083
  raise ValueError(
1084
+ f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
1085
  )
1086
  return data
1087
 
src/axolotl/utils/data/sft.py CHANGED
@@ -474,12 +474,16 @@ def load_prepare_datasets(
474
  index=cfg.dataset_shard_idx,
475
  )
476
 
477
- if split == "train" and cfg.val_set_size:
 
 
 
 
478
  # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
479
  to_hash_train = (
480
  dataset._fingerprint # pylint: disable=protected-access
481
  + "|"
482
- + str(cfg.val_set_size)
483
  + "|"
484
  + "train"
485
  + "|"
@@ -488,7 +492,7 @@ def load_prepare_datasets(
488
  to_hash_test = (
489
  dataset._fingerprint # pylint: disable=protected-access
490
  + "|"
491
- + str(cfg.val_set_size)
492
  + "|"
493
  + "test"
494
  + "|"
@@ -498,9 +502,7 @@ def load_prepare_datasets(
498
  test_fingerprint = md5(to_hash_test)
499
 
500
  dataset = dataset.train_test_split(
501
- test_size=int(cfg.val_set_size)
502
- if cfg.val_set_size == int(cfg.val_set_size)
503
- else cfg.val_set_size,
504
  shuffle=False,
505
  seed=cfg.seed or 42,
506
  train_new_fingerprint=train_fingerprint,
@@ -535,6 +537,10 @@ def get_dataset_wrapper(
535
  "keep_in_memory": cfg.dataset_keep_in_memory is True,
536
  }
537
 
 
 
 
 
538
  if (
539
  isinstance(dataset, Dataset)
540
  and "input_ids" in dataset.features
 
474
  index=cfg.dataset_shard_idx,
475
  )
476
 
477
+ val_set_size = (
478
+ int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
479
+ )
480
+
481
+ if split == "train" and val_set_size:
482
  # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
483
  to_hash_train = (
484
  dataset._fingerprint # pylint: disable=protected-access
485
  + "|"
486
+ + str(val_set_size)
487
  + "|"
488
  + "train"
489
  + "|"
 
492
  to_hash_test = (
493
  dataset._fingerprint # pylint: disable=protected-access
494
  + "|"
495
+ + str(val_set_size)
496
  + "|"
497
  + "test"
498
  + "|"
 
502
  test_fingerprint = md5(to_hash_test)
503
 
504
  dataset = dataset.train_test_split(
505
+ test_size=val_set_size,
 
 
506
  shuffle=False,
507
  seed=cfg.seed or 42,
508
  train_new_fingerprint=train_fingerprint,
 
537
  "keep_in_memory": cfg.dataset_keep_in_memory is True,
538
  }
539
 
540
+ LOG.info(
541
+ f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
542
+ )
543
+
544
  if (
545
  isinstance(dataset, Dataset)
546
  and "input_ids" in dataset.features
tests/test_perplexity.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """unit tests for perplexity eval callback"""
2
+ # pylint: disable=redefined-outer-name
3
+
4
+ from pytest import fixture
5
+ from transformers.models.auto.modeling_auto import AutoModelForCausalLM
6
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
7
+
8
+ from axolotl.utils.callbacks.perplexity import Perplexity
9
+
10
+ MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
11
+
12
+
13
+ @fixture()
14
+ def metric(tokenizer):
15
+ model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)
16
+
17
+ return Perplexity(model, tokenizer, 512)
18
+
19
+
20
+ @fixture()
21
+ def tokenizer():
22
+ return AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
23
+
24
+
25
+ def test_perplexity_longer_than_stride(metric):
26
+ # taken from https://huggingface.co/datasets/roneneldan/TinyStories
27
+ sample_text = """
28
+ Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong. One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn. Beep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after.
29
+ One day, a little fish named Fin was swimming near the shore. He saw a big crab and wanted to be friends. "Hi, I am Fin. Do you want to play?" asked the little fish. The crab looked at Fin and said, "No, I don't want to play. I am cold and I don't feel fine." Fin felt sad but wanted to help the crab feel better. He swam away and thought of a plan. He remembered that the sun could make things warm. So, Fin swam to the top of the water and called to the sun, "Please, sun, help my new friend feel fine and not freeze!" The sun heard Fin's call and shone its warm light on the shore. The crab started to feel better and not so cold. He saw Fin and said, "Thank you, little fish, for making me feel fine. I don't feel like I will freeze now. Let's play together!" And so, Fin and the crab played and became good friends.
30
+ """
31
+ result = metric.compute([sample_text])
32
+ ppl = result["score"]
33
+ assert round(ppl, 2) == 5.37
34
+
35
+
36
+ def test_perplexity_short(metric):
37
+ # taken from https://huggingface.co/datasets/roneneldan/TinyStories
38
+ sample_text = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun."
39
+ result = metric.compute([sample_text])
40
+ ppl = result["score"]
41
+ assert round(ppl, 2) == 10.02