Phi-3 conversation format, example training script and perplexity metric (#1582)
Browse files
* phi-3 support and perplexity metric
* phi-3 chat template
* metrics updates
* chore: lint
* fix assertion on Tensor
* fix tests since tokenization happens in the metric
* fix perplexity value of shorter passage
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
- .gitignore +6 -0
- examples/phi/phi3-ft.yml +64 -0
- src/axolotl/prompters.py +11 -4
- src/axolotl/utils/callbacks/__init__.py +24 -10
- src/axolotl/utils/callbacks/perplexity.py +76 -0
- src/axolotl/utils/chat_templates.py +1 -0
- src/axolotl/utils/config/__init__.py +3 -3
- src/axolotl/utils/config/models/input/v0_4_1/__init__.py +5 -3
- src/axolotl/utils/data/sft.py +12 -6
- tests/test_perplexity.py +41 -0
.gitignore
CHANGED
@@ -176,3 +176,9 @@ qlora-out/*
|
|
176 |
mlruns/*
|
177 |
|
178 |
/.quarto/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
mlruns/*
|
177 |
|
178 |
/.quarto/
|
179 |
+
prepared-datasets/
|
180 |
+
submit.sh
|
181 |
+
*.out*
|
182 |
+
|
183 |
+
typings/
|
184 |
+
out/
|
examples/phi/phi3-ft.yml
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
base_model: microsoft/Phi-3-mini-4k-instruct
|
2 |
+
trust_remote_code: true
|
3 |
+
model_type: AutoModelForCausalLM
|
4 |
+
tokenizer_type: AutoTokenizer
|
5 |
+
chat_template: phi_3
|
6 |
+
|
7 |
+
load_in_8bit: false
|
8 |
+
load_in_4bit: false
|
9 |
+
strict: false
|
10 |
+
|
11 |
+
datasets:
|
12 |
+
- path: garage-bAInd/Open-Platypus
|
13 |
+
type: alpaca:phi
|
14 |
+
|
15 |
+
dataset_prepared_path:
|
16 |
+
val_set_size: 0.01
|
17 |
+
output_dir: ./out
|
18 |
+
|
19 |
+
sequence_len: 4096
|
20 |
+
sample_packing: true
|
21 |
+
pad_to_sequence_len: true
|
22 |
+
|
23 |
+
adapter: lora
|
24 |
+
lora_model_dir:
|
25 |
+
lora_r: 64
|
26 |
+
lora_alpha: 32
|
27 |
+
lora_dropout: 0.05
|
28 |
+
lora_target_linear: true
|
29 |
+
lora_fan_in_fan_out:
|
30 |
+
|
31 |
+
gradient_accumulation_steps: 1
|
32 |
+
micro_batch_size: 2
|
33 |
+
num_epochs: 1
|
34 |
+
optimizer: adamw_torch
|
35 |
+
adam_beta2: 0.95
|
36 |
+
adam_epsilon: 0.00001
|
37 |
+
max_grad_norm: 1.0
|
38 |
+
lr_scheduler: cosine
|
39 |
+
learning_rate: 5.0e-6
|
40 |
+
|
41 |
+
train_on_inputs: false
|
42 |
+
group_by_length: false
|
43 |
+
bf16: auto
|
44 |
+
|
45 |
+
gradient_checkpointing: true
|
46 |
+
gradient_checkpointing_kwargs:
|
47 |
+
use_reentrant: True
|
48 |
+
early_stopping_patience: 3
|
49 |
+
logging_steps: 1
|
50 |
+
flash_attention: true
|
51 |
+
|
52 |
+
eval_steps: 1000
|
53 |
+
save_steps: 5000
|
54 |
+
eval_table_size: 2
|
55 |
+
eval_batch_size: 2
|
56 |
+
eval_sample_packing: false
|
57 |
+
eval_max_new_tokens: 32
|
58 |
+
eval_causal_lm_metrics: ["perplexity"]
|
59 |
+
do_causal_lm_eval: true
|
60 |
+
|
61 |
+
warmup_ratio: 0.2
|
62 |
+
debug: true
|
63 |
+
weight_decay: 0.1
|
64 |
+
resize_token_embeddings_to_32x: true
|
src/axolotl/prompters.py
CHANGED
@@ -20,6 +20,7 @@ class PromptStyle(Enum):
|
|
20 |
INSTRUCT = "instruct"
|
21 |
CHAT = "chat"
|
22 |
CHATML = "chatml"
|
|
|
23 |
|
24 |
|
25 |
class Prompter:
|
@@ -38,9 +39,9 @@ class AlpacaPrompter(Prompter):
|
|
38 |
system_format: str = "{system}"
|
39 |
turn_format: str
|
40 |
turn_no_input_format: str
|
41 |
-
prompt_style: Optional[
|
42 |
|
43 |
-
def __init__(self, prompt_style=PromptStyle.INSTRUCT.value):
|
44 |
self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
|
45 |
self.match_prompt_style()
|
46 |
|
@@ -52,16 +53,20 @@ class AlpacaPrompter(Prompter):
|
|
52 |
"### Instruction:\n{instruction}\n\n### Response:\n"
|
53 |
)
|
54 |
self.system_format = "{system}\n\n"
|
55 |
-
|
56 |
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
57 |
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
58 |
self.system_format = "SYSTEM: {system}\n"
|
59 |
-
|
60 |
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
61 |
self.turn_no_input_format = (
|
62 |
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
|
63 |
)
|
64 |
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
|
|
|
|
|
|
|
|
65 |
|
66 |
def _build_result(self, instruction, input_text, output):
|
67 |
# returns the full prompt from instruction and optional input
|
@@ -381,12 +386,14 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
|
381 |
conversation: Optional[Union[str, Conversation]] = None,
|
382 |
role_key_human: Optional[str] = None,
|
383 |
role_key_model: Optional[str] = None,
|
|
|
384 |
roles: Optional[dict] = None,
|
385 |
):
|
386 |
super().__init__(
|
387 |
conversation=conversation,
|
388 |
role_key_human=role_key_human,
|
389 |
role_key_model=role_key_model,
|
|
|
390 |
roles=roles,
|
391 |
)
|
392 |
|
|
|
20 |
INSTRUCT = "instruct"
|
21 |
CHAT = "chat"
|
22 |
CHATML = "chatml"
|
23 |
+
PHI = "phi"
|
24 |
|
25 |
|
26 |
class Prompter:
|
|
|
39 |
system_format: str = "{system}"
|
40 |
turn_format: str
|
41 |
turn_no_input_format: str
|
42 |
+
prompt_style: Optional[str] = None
|
43 |
|
44 |
+
def __init__(self, prompt_style: Optional[str] = PromptStyle.INSTRUCT.value):
|
45 |
self.prompt_style = prompt_style if prompt_style else PromptStyle.INSTRUCT.value
|
46 |
self.match_prompt_style()
|
47 |
|
|
|
53 |
"### Instruction:\n{instruction}\n\n### Response:\n"
|
54 |
)
|
55 |
self.system_format = "{system}\n\n"
|
56 |
+
elif self.prompt_style == PromptStyle.CHAT.value:
|
57 |
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
|
58 |
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
|
59 |
self.system_format = "SYSTEM: {system}\n"
|
60 |
+
elif self.prompt_style == PromptStyle.CHATML.value:
|
61 |
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
|
62 |
self.turn_no_input_format = (
|
63 |
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
|
64 |
)
|
65 |
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
66 |
+
elif self.prompt_style == PromptStyle.PHI.value:
|
67 |
+
self.turn_format = "<|user|>\n{instruction}<|end|>{input}<|assistant|>"
|
68 |
+
self.turn_no_input_format = "<|user|>\n{instruction}<|end|><|assistant|>"
|
69 |
+
self.system_format = "<|system|>{system}\n"
|
70 |
|
71 |
def _build_result(self, instruction, input_text, output):
|
72 |
# returns the full prompt from instruction and optional input
|
|
|
386 |
conversation: Optional[Union[str, Conversation]] = None,
|
387 |
role_key_human: Optional[str] = None,
|
388 |
role_key_model: Optional[str] = None,
|
389 |
+
role_key_tool: Optional[str] = None,
|
390 |
roles: Optional[dict] = None,
|
391 |
):
|
392 |
super().__init__(
|
393 |
conversation=conversation,
|
394 |
role_key_human=role_key_human,
|
395 |
role_key_model=role_key_model,
|
396 |
+
role_key_tool=role_key_tool,
|
397 |
roles=roles,
|
398 |
)
|
399 |
|
src/axolotl/utils/callbacks/__init__.py
CHANGED
@@ -5,6 +5,7 @@ from __future__ import annotations
|
|
5 |
import logging
|
6 |
import math
|
7 |
import os
|
|
|
8 |
from shutil import copyfile
|
9 |
from tempfile import NamedTemporaryFile
|
10 |
from typing import TYPE_CHECKING, Any, Dict, List
|
@@ -30,6 +31,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
|
30 |
|
31 |
from axolotl.utils import is_mlflow_available
|
32 |
from axolotl.utils.bench import log_gpu_memory_usage
|
|
|
33 |
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
34 |
from axolotl.utils.distributed import (
|
35 |
barrier,
|
@@ -374,10 +376,14 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
|
374 |
def __maybe_load_metrics(self):
|
375 |
metrics = {}
|
376 |
for metric in self.cfg.eval_causal_lm_metrics:
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
|
|
|
|
|
|
|
|
381 |
return metrics
|
382 |
|
383 |
def on_evaluate(
|
@@ -421,13 +427,20 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
|
421 |
# safely compute a metric and return the score if the format is correct
|
422 |
metric_score = None
|
423 |
try:
|
424 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
425 |
return (
|
426 |
metric_score["score"]
|
427 |
if "score" in metric_score
|
428 |
else metric_score["mean_score"]
|
429 |
)
|
430 |
except Exception: # pylint: disable=broad-exception-caught
|
|
|
431 |
LOG.debug(
|
432 |
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
|
433 |
)
|
@@ -443,11 +456,12 @@ def causal_lm_bench_eval_callback_factory(trainer: Trainer, tokenizer):
|
|
443 |
predictions=predictions,
|
444 |
sources=sources,
|
445 |
)
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
|
|
451 |
scores[metric_name] = score
|
452 |
return scores
|
453 |
|
|
|
5 |
import logging
|
6 |
import math
|
7 |
import os
|
8 |
+
import traceback
|
9 |
from shutil import copyfile
|
10 |
from tempfile import NamedTemporaryFile
|
11 |
from typing import TYPE_CHECKING, Any, Dict, List
|
|
|
31 |
|
32 |
from axolotl.utils import is_mlflow_available
|
33 |
from axolotl.utils.bench import log_gpu_memory_usage
|
34 |
+
from axolotl.utils.callbacks.perplexity import Perplexity
|
35 |
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
36 |
from axolotl.utils.distributed import (
|
37 |
barrier,
|
|
|
376 |
def __maybe_load_metrics(self):
    """
    Build a mapping of metric name -> metric object for every entry in
    ``self.cfg.eval_causal_lm_metrics``.

    "perplexity" is served by the local ``Perplexity`` class using the
    trainer's model and the tokenizer captured from the enclosing factory;
    every other name is resolved through ``evaluate.load``. A metric that
    fails to load is logged as a warning and left out of the result.

    NOTE(review): the perplexity window size is read from
    ``cfg.eval_max_new_tokens`` -- confirm this is the intended context length.
    """
    loaded = {}
    for metric_name in self.cfg.eval_causal_lm_metrics:
        if metric_name == "perplexity":
            window_size = self.cfg.eval_max_new_tokens
            loaded[metric_name] = Perplexity(trainer.model, tokenizer, window_size)
            continue

        # best effort: a metric that cannot be loaded is reported and skipped
        try:
            loaded[metric_name] = evaluate.load(metric_name)
        except Exception as exc:  # pylint: disable=broad-exception-caught
            LOG.warning(f"{metric_name}: {exc.args}")
    return loaded
|
388 |
|
389 |
def on_evaluate(
|
|
|
427 |
# safely compute a metric and return the score if the format is correct
|
428 |
metric_score = None
|
429 |
try:
|
430 |
+
# Only pass the kwargs that are in the metric's feature list
|
431 |
+
metric_kwargs = {
|
432 |
+
k: kwargs[k]
|
433 |
+
for k in metric._feature_names() # pylint: disable=protected-access
|
434 |
+
if k in kwargs
|
435 |
+
}
|
436 |
+
metric_score = metric.compute(**metric_kwargs)
|
437 |
return (
|
438 |
metric_score["score"]
|
439 |
if "score" in metric_score
|
440 |
else metric_score["mean_score"]
|
441 |
)
|
442 |
except Exception: # pylint: disable=broad-exception-caught
|
443 |
+
traceback.print_exc()
|
444 |
LOG.debug(
|
445 |
f"Failed to compute metric {metric.name} with kwargs {kwargs.keys()}"
|
446 |
)
|
|
|
456 |
predictions=predictions,
|
457 |
sources=sources,
|
458 |
)
|
459 |
+
if score is None:
|
460 |
+
score = compute(
|
461 |
+
metric,
|
462 |
+
references=[[r] for r in references],
|
463 |
+
predictions=predictions,
|
464 |
+
)
|
465 |
scores[metric_name] = score
|
466 |
return scores
|
467 |
|
src/axolotl/utils/callbacks/perplexity.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""callback to calculate perplexity as an evaluation metric."""
|
2 |
+
from typing import Dict, List, Optional
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from torch import Tensor
|
6 |
+
from tqdm import tqdm
|
7 |
+
from transformers.modeling_outputs import CausalLMOutput
|
8 |
+
from transformers.modeling_utils import PreTrainedModel
|
9 |
+
from transformers.tokenization_utils import PreTrainedTokenizer
|
10 |
+
|
11 |
+
|
12 |
+
class Perplexity:
    """
    Calculate perplexity as defined in https://huggingface.co/docs/transformers/en/perplexity.

    This is a custom variant that re-uses an already loaded model instead of
    loading one by name. The reference strings are tokenized inside
    ``compute`` with the tokenizer supplied at construction time.
    """

    def __init__(
        self,
        model: PreTrainedModel,
        tokenizer: PreTrainedTokenizer,
        max_seq_len: int,
        stride: int = 512,
    ) -> None:
        """
        Args:
            model: causal LM used to score the references.
            tokenizer: tokenizer matching ``model``.
            max_seq_len: maximum number of tokens fed to the model per window.
            stride: number of tokens the window advances per step.
        """
        self.max_seq_len = max_seq_len
        self.stride = stride
        self.model = model
        self.tokenizer = tokenizer
        self.device = model.device
        self.name = "perplexity"

    def _feature_names(self) -> List[str]:
        """Return the keyword arguments that ``compute`` accepts."""
        return ["references"]

    def compute(
        self,
        references: Optional[List[str]] = None,
    ) -> Dict[str, float]:
        """
        Compute perplexity in a fixed length sliding window across the sequence.

        Args:
            references: texts to score; they are tokenized as one padded batch.

        Returns:
            ``{"score": perplexity}`` where perplexity is ``exp`` of the mean
            of the per-window losses returned by the model.

        Raises:
            ValueError: if ``references`` is missing or empty, if tokenization
                yields no tokens, or if the window configuration is not positive.
        """
        if references is None:
            raise ValueError("Missing parameter: references")
        if len(references) == 0:
            raise ValueError("references must contain at least one text")

        # Never advance further than one window, otherwise the tokens between
        # the end of one window and the start of the next are never scored.
        step = min(self.stride, self.max_seq_len)
        if step < 1:
            raise ValueError("max_seq_len and stride must be positive")

        references_tokenized = self.tokenizer(
            references, return_tensors="pt", padding=True, truncation=True
        )
        input_ids: Tensor = references_tokenized["input_ids"]  # type: ignore
        input_ids = input_ids.to(self.device)

        attention_mask: Optional[Tensor] = None
        if "attention_mask" in references_tokenized:
            attention_mask = references_tokenized["attention_mask"].to(self.device)

        sequence_length = input_ids.size(1)
        if sequence_length == 0:
            raise ValueError("references produced no tokens to score")

        losses = []
        prev_end_loc = 0
        for begin_loc in tqdm(range(0, sequence_length, step)):
            end_loc = min(begin_loc + self.max_seq_len, sequence_length)
            # only the tokens not covered by the previous window are targets
            trg_len = end_loc - prev_end_loc
            input_ids_slice = input_ids[:, begin_loc:end_loc]
            labels_slice = input_ids_slice.clone()
            labels_slice[:, :-trg_len] = -100

            attention_mask_slice = None
            if attention_mask is not None:
                attention_mask_slice = attention_mask[:, begin_loc:end_loc]
                # padding added for batching must not count towards the loss
                labels_slice[attention_mask_slice == 0] = -100

            with torch.no_grad():
                outputs: CausalLMOutput = self.model(
                    input_ids=input_ids_slice,
                    attention_mask=attention_mask_slice,
                    labels=labels_slice,
                )

            losses.append(outputs.loss)

            prev_end_loc = end_loc
            if end_loc == sequence_length:
                break

        perplexity = torch.exp(torch.stack(losses).mean()).item()

        return {
            "score": perplexity,
        }
|
src/axolotl/utils/chat_templates.py
CHANGED
@@ -25,6 +25,7 @@ def chat_templates(user_choice: str):
|
|
25 |
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
26 |
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
27 |
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
|
|
|
28 |
}
|
29 |
|
30 |
if user_choice in templates:
|
|
|
25 |
"gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
|
26 |
"cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}",
|
27 |
"llama3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
|
28 |
+
"phi_3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'system') %}{{'<|system|>' + '\n' + message['content'] + '<|end|>' + '\n'}}{% elif (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif message['role'] == 'assistant' %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
|
29 |
}
|
30 |
|
31 |
if user_choice in templates:
|
src/axolotl/utils/config/__init__.py
CHANGED
@@ -10,6 +10,7 @@ from transformers.utils import is_torch_bf16_gpu_available
|
|
10 |
|
11 |
from axolotl.utils.bench import log_gpu_memory_usage
|
12 |
from axolotl.utils.config.models.input.v0_4_1 import (
|
|
|
13 |
AxolotlConfigWCapabilities,
|
14 |
AxolotlInputConfig,
|
15 |
)
|
@@ -586,13 +587,12 @@ def legacy_validate_config(cfg):
|
|
586 |
)
|
587 |
|
588 |
if cfg.eval_causal_lm_metrics:
|
589 |
-
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
|
590 |
if not isinstance(cfg.eval_causal_lm_metrics, list):
|
591 |
raise ValueError("eval_causal_lm_metrics must be a list")
|
592 |
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
593 |
-
if set(cfg.eval_causal_lm_metrics) -
|
594 |
raise ValueError(
|
595 |
-
f"eval_causal_lm_metrics must be one of {
|
596 |
)
|
597 |
|
598 |
# TODO
|
|
|
10 |
|
11 |
from axolotl.utils.bench import log_gpu_memory_usage
|
12 |
from axolotl.utils.config.models.input.v0_4_1 import (
|
13 |
+
SUPPORTED_METRICS,
|
14 |
AxolotlConfigWCapabilities,
|
15 |
AxolotlInputConfig,
|
16 |
)
|
|
|
587 |
)
|
588 |
|
589 |
if cfg.eval_causal_lm_metrics:
|
|
|
590 |
if not isinstance(cfg.eval_causal_lm_metrics, list):
|
591 |
raise ValueError("eval_causal_lm_metrics must be a list")
|
592 |
# only the metrics listed in SUPPORTED_METRICS are supported
|
593 |
+
if set(cfg.eval_causal_lm_metrics) - SUPPORTED_METRICS:
|
594 |
raise ValueError(
|
595 |
+
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
|
596 |
)
|
597 |
|
598 |
# TODO
|
src/axolotl/utils/config/models/input/v0_4_1/__init__.py
CHANGED
@@ -17,6 +17,8 @@ from axolotl.utils.config.models.internals import GPUCapabilities
|
|
17 |
|
18 |
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
19 |
|
|
|
|
|
20 |
|
21 |
class DeprecatedParameters(BaseModel):
|
22 |
"""configurations that are deprecated"""
|
@@ -176,6 +178,7 @@ class ChatTemplate(str, Enum):
|
|
176 |
gemma = "gemma" # pylint: disable=invalid-name
|
177 |
cohere = "cohere" # pylint: disable=invalid-name
|
178 |
llama3 = "llama3" # pylint: disable=invalid-name
|
|
|
179 |
|
180 |
|
181 |
class LoftQConfig(BaseModel):
|
@@ -1073,13 +1076,12 @@ class AxolotlInputConfig(
|
|
1073 |
)
|
1074 |
|
1075 |
if data.get("eval_causal_lm_metrics"):
|
1076 |
-
supported_metrics = ["sacrebleu", "comet", "ter", "chrf"]
|
1077 |
if not isinstance(data.get("eval_causal_lm_metrics"), list):
|
1078 |
raise ValueError("eval_causal_lm_metrics must be a list")
|
1079 |
# only ["sacrebleu", "comet", "ter", "chrf"] supported
|
1080 |
-
if set(data.get("eval_causal_lm_metrics")) -
|
1081 |
raise ValueError(
|
1082 |
-
f"eval_causal_lm_metrics must be one of {
|
1083 |
)
|
1084 |
return data
|
1085 |
|
|
|
17 |
|
18 |
LOG = logging.getLogger("axolotl.utils.config.models.input")
|
19 |
|
20 |
+
SUPPORTED_METRICS = {"sacrebleu", "comet", "ter", "chrf", "perplexity"}
|
21 |
+
|
22 |
|
23 |
class DeprecatedParameters(BaseModel):
|
24 |
"""configurations that are deprecated"""
|
|
|
178 |
gemma = "gemma" # pylint: disable=invalid-name
|
179 |
cohere = "cohere" # pylint: disable=invalid-name
|
180 |
llama3 = "llama3" # pylint: disable=invalid-name
|
181 |
+
phi_3 = "phi_3" # pylint: disable=invalid-name
|
182 |
|
183 |
|
184 |
class LoftQConfig(BaseModel):
|
|
|
1076 |
)
|
1077 |
|
1078 |
if data.get("eval_causal_lm_metrics"):
|
|
|
1079 |
if not isinstance(data.get("eval_causal_lm_metrics"), list):
|
1080 |
raise ValueError("eval_causal_lm_metrics must be a list")
|
1081 |
# only the metrics listed in SUPPORTED_METRICS are supported
|
1082 |
+
if set(data.get("eval_causal_lm_metrics")) - SUPPORTED_METRICS:
|
1083 |
raise ValueError(
|
1084 |
+
f"eval_causal_lm_metrics must be one of {SUPPORTED_METRICS}"
|
1085 |
)
|
1086 |
return data
|
1087 |
|
src/axolotl/utils/data/sft.py
CHANGED
@@ -474,12 +474,16 @@ def load_prepare_datasets(
|
|
474 |
index=cfg.dataset_shard_idx,
|
475 |
)
|
476 |
|
477 |
-
|
|
|
|
|
|
|
|
|
478 |
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
479 |
to_hash_train = (
|
480 |
dataset._fingerprint # pylint: disable=protected-access
|
481 |
+ "|"
|
482 |
-
+ str(
|
483 |
+ "|"
|
484 |
+ "train"
|
485 |
+ "|"
|
@@ -488,7 +492,7 @@ def load_prepare_datasets(
|
|
488 |
to_hash_test = (
|
489 |
dataset._fingerprint # pylint: disable=protected-access
|
490 |
+ "|"
|
491 |
-
+ str(
|
492 |
+ "|"
|
493 |
+ "test"
|
494 |
+ "|"
|
@@ -498,9 +502,7 @@ def load_prepare_datasets(
|
|
498 |
test_fingerprint = md5(to_hash_test)
|
499 |
|
500 |
dataset = dataset.train_test_split(
|
501 |
-
test_size=
|
502 |
-
if cfg.val_set_size == int(cfg.val_set_size)
|
503 |
-
else cfg.val_set_size,
|
504 |
shuffle=False,
|
505 |
seed=cfg.seed or 42,
|
506 |
train_new_fingerprint=train_fingerprint,
|
@@ -535,6 +537,10 @@ def get_dataset_wrapper(
|
|
535 |
"keep_in_memory": cfg.dataset_keep_in_memory is True,
|
536 |
}
|
537 |
|
|
|
|
|
|
|
|
|
538 |
if (
|
539 |
isinstance(dataset, Dataset)
|
540 |
and "input_ids" in dataset.features
|
|
|
474 |
index=cfg.dataset_shard_idx,
|
475 |
)
|
476 |
|
477 |
+
val_set_size = (
|
478 |
+
int(cfg.val_set_size) if cfg.val_set_size > 1 else float(cfg.val_set_size)
|
479 |
+
)
|
480 |
+
|
481 |
+
if split == "train" and val_set_size:
|
482 |
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
|
483 |
to_hash_train = (
|
484 |
dataset._fingerprint # pylint: disable=protected-access
|
485 |
+ "|"
|
486 |
+
+ str(val_set_size)
|
487 |
+ "|"
|
488 |
+ "train"
|
489 |
+ "|"
|
|
|
492 |
to_hash_test = (
|
493 |
dataset._fingerprint # pylint: disable=protected-access
|
494 |
+ "|"
|
495 |
+
+ str(val_set_size)
|
496 |
+ "|"
|
497 |
+ "test"
|
498 |
+ "|"
|
|
|
502 |
test_fingerprint = md5(to_hash_test)
|
503 |
|
504 |
dataset = dataset.train_test_split(
|
505 |
+
test_size=val_set_size,
|
|
|
|
|
506 |
shuffle=False,
|
507 |
seed=cfg.seed or 42,
|
508 |
train_new_fingerprint=train_fingerprint,
|
|
|
537 |
"keep_in_memory": cfg.dataset_keep_in_memory is True,
|
538 |
}
|
539 |
|
540 |
+
LOG.info(
|
541 |
+
f"Loading dataset with base_type: {d_base_type} and prompt_style: {d_prompt_style}"
|
542 |
+
)
|
543 |
+
|
544 |
if (
|
545 |
isinstance(dataset, Dataset)
|
546 |
and "input_ids" in dataset.features
|
tests/test_perplexity.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""unit tests for perplexity eval callback"""
|
2 |
+
# pylint: disable=redefined-outer-name
|
3 |
+
|
4 |
+
from pytest import fixture
|
5 |
+
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
|
6 |
+
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
7 |
+
|
8 |
+
from axolotl.utils.callbacks.perplexity import Perplexity
|
9 |
+
|
10 |
+
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
11 |
+
|
12 |
+
|
13 |
+
@fixture(scope="module")
def metric(tokenizer):
    """
    Perplexity metric backed by the MODEL_NAME checkpoint with a 512 token window.

    Module scoped so the checkpoint is loaded once for all tests in this
    file instead of once per test.
    """
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)

    return Perplexity(model, tokenizer, 512)


@fixture(scope="module")
def tokenizer():
    """
    Tokenizer for MODEL_NAME.

    Must have the same (module) scope as ``metric``, which depends on it.
    """
    return AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
|
23 |
+
|
24 |
+
|
25 |
+
def test_perplexity_longer_than_stride(metric):
    """Score a two-story passage and pin the resulting perplexity."""
    # taken from https://huggingface.co/datasets/roneneldan/TinyStories
    passage = """
Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong. One day, Beep was driving in the park when he saw a big tree. The tree had many leaves that were falling. Beep liked how the leaves fall and wanted to play with them. Beep drove under the tree and watched the leaves fall on him. He laughed and beeped his horn. Beep played with the falling leaves all day. When it was time to go home, Beep knew he needed more fuel. He went to the fuel place and got more healthy fuel. Now, Beep was ready to go fast and play again the next day. And Beep lived happily ever after.
One day, a little fish named Fin was swimming near the shore. He saw a big crab and wanted to be friends. "Hi, I am Fin. Do you want to play?" asked the little fish. The crab looked at Fin and said, "No, I don't want to play. I am cold and I don't feel fine." Fin felt sad but wanted to help the crab feel better. He swam away and thought of a plan. He remembered that the sun could make things warm. So, Fin swam to the top of the water and called to the sun, "Please, sun, help my new friend feel fine and not freeze!" The sun heard Fin's call and shone its warm light on the shore. The crab started to feel better and not so cold. He saw Fin and said, "Thank you, little fish, for making me feel fine. I don't feel like I will freeze now. Let's play together!" And so, Fin and the crab played and became good friends.
"""
    score = metric.compute([passage])["score"]
    assert round(score, 2) == 5.37
|
34 |
+
|
35 |
+
|
36 |
+
def test_perplexity_short(metric):
    """Score a two-sentence passage and pin the resulting perplexity."""
    # taken from https://huggingface.co/datasets/roneneldan/TinyStories
    passage = "Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun."
    score = metric.compute([passage])["score"]
    assert round(score, 2) == 10.02
|