hamel committed in f1de29d (1 parent: 7fabc4d)

Respect sequence_len in config for `type: llama2_chat` (#926)


* Respect sequence_len in config for `type: llama2_chat`

It was hardcoded to `4096`, and I am not sure why. This updates it to pull from the config instead (a sketch of why that works follows this list).

cc: @winglian



* Update llama2_chat.py

* apply black formatting

* fix tokenizer

* update test data

* lint fixtures
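
To make the fix concrete, here is a minimal sketch of why deleting the hardcoded assignment is enough. The base-class signature below is an assumption for illustration (a simplified stand-in for axolotl's `PromptTokenizingStrategy`, not the real class); the point is that the parent `__init__` already stores the configured `sequence_len`, and the subclass was overwriting it:

    # Minimal sketch, not the real axolotl classes; the base-class
    # signature is assumed for illustration.

    class PromptTokenizingStrategy:
        """Simplified stand-in for axolotl's base tokenizing strategy."""

        def __init__(self, prompter, tokenizer, train_on_inputs=False, sequence_len=2048):
            self.prompter = prompter
            self.tokenizer = tokenizer
            self.train_on_inputs = train_on_inputs
            # `sequence_len` arrives here from the `sequence_len:` key in the config.
            self.sequence_len = sequence_len


    class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Before this commit, a `self.sequence_len = 4096` here silently
            # clobbered the inherited, config-driven value. Removing it lets
            # the configured length through unchanged.


    strategy = LLama2ChatTokenizingStrategy(None, None, sequence_len=1024)
    print(strategy.sequence_len)  # 1024, as configured; no longer forced to 4096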

src/axolotl/prompt_strategies/llama2_chat.py CHANGED
@@ -81,8 +81,9 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.sequence_len = 4096
-        self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
+        self.tokenizer.add_special_tokens(
+            {"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")}
+        )
         # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
 
     def tokenize_prompt(self, prompt):
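
A note on the `getattr` fallback in the added lines: instead of unconditionally registering `"<pad>"`, the strategy now reuses a pad token the tokenizer already defines. Below is a rough sketch of that behavior using the transformers API; the model name is illustrative, and the `or "<pad>"` guard is an addition for this sketch, since transformers tokenizers expose `pad_token` as an attribute whose value may be `None`:

    # Rough sketch of the pad-token fallback, assuming a transformers tokenizer.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    # Reuse an existing pad token if the tokenizer defines one; otherwise
    # register "<pad>", which adds one new token to the vocabulary. The
    # `or "<pad>"` guard is this sketch's addition: because `pad_token`
    # exists as an attribute, `getattr` alone can return None.
    pad_token = getattr(tokenizer, "pad_token", None) or "<pad>"
    tokenizer.add_special_tokens({"pad_token": pad_token})
    print(tokenizer.pad_token)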
tests/fixtures/conversation.tokenized_llama2chat.json CHANGED
The diff for this file is too large to render.