winglian committed on
Commit 81de0ef
1 Parent(s): 34af1b4

add support for alpaca reflect training (#2)

configs/vicuna_13B_4bit_reflect.yml ADDED
@@ -0,0 +1,45 @@
+ base_model: anon8231489123/vicuna-13b-GPTQ-4bit-128g
+ base_model_config: anon8231489123/vicuna-13b-GPTQ-4bit-128g
+ model_type: LlamaForCausalLM
+ tokenizer_type: LlamaTokenizer
+ load_in_8bit: false
+ load_4bit: true
+ gptq_groupsize: 128
+ gptq_model_v1: false
+ datasets:
+   # https://github.com/vaguenebula/AlpacaDataReflect/blob/main/alpaca_reflect_pruned.json
+   - path: data/alpaca_reflect_pruned.jsonl
+     type: reflection
+ dataset_prepared_path: data/last_run_prepared
+ val_set_size: 0.04
+ adapter: lora
+ lora_model_dir:
+ sequence_len: 2048
+ max_packed_sequence_len: 2048
+ lora_r: 8
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_modules:
+   - q_proj
+   - v_proj
+   # - k_proj
+   # - o_proj
+ lora_fan_in_fan_out: false
+ wandb_project:
+ wandb_watch:
+ wandb_run_id:
+ wandb_log_model: checkpoint
+ output_dir: ./lora-reflect
+ batch_size: 8
+ micro_batch_size: 2
+ num_epochs: 3
+ learning_rate: 0.00003
+ train_on_inputs: false
+ group_by_length: false
+ bf16: true
+ tf32: true
+ gradient_checkpointing: false
+ early_stopping_patience: 3
+ resume_from_checkpoint:
+ local_rank:
+ flash_attention: true
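
The type: reflection dataset above must supply the fields that AlpacaReflectionPTStrategy.parse_instruction_fields reads (see src/axolotl/prompt_tokenizers.py below). A hypothetical record from data/alpaca_reflect_pruned.jsonl, loaded as a Python dict, would look roughly like this (only the keys come from the code; the values are invented for illustration):

# Hypothetical reflection-style record; key names taken from parse_instruction_fields below.
example_row = {
    "instruction": "Explain why the sky appears blue.",
    "input": "",  # optional; the strategy substitutes "" when the key is absent
    "output": "A first-draft answer from the assistant ...",
    "reflection": "A critique pointing out flaws or gaps in the draft ...",
    "corrected": "The final answer, revised in light of the reflection ...",
}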
scripts/finetune.py CHANGED
@@ -37,9 +37,9 @@ from axolotl.prompt_tokenizers import (
      ShareGPTPromptTokenizingStrategy,
      LLAMA_DEFAULT_PAD_TOKEN,
      GPTeacherPromptTokenizingStrategy,
-     OpenAssistantPromptTokenizingStrategy,
+     OpenAssistantPromptTokenizingStrategy, AlpacaReflectionPTStrategy,
  )
- from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
+ from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter, ReflectAlpacaPrompter

  logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO"))
  DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
@@ -395,6 +395,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
          )
          trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler)

+     # TODO on_save callback to sync checkpoints to GCP/AWS in background
      if cfg.early_stopping_patience:
          early_stop_cb = EarlyStoppingCallback(
              cfg.early_stopping_patience,
@@ -540,6 +541,15 @@ def train(
                  )
                  ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
                  datasets.append(ds_wrapper)
+             elif d.type == "reflection":
+                 ds_strategy = AlpacaReflectionPTStrategy(
+                     ReflectAlpacaPrompter(),
+                     tokenizer,
+                     cfg.train_on_inputs,
+                     cfg.sequence_len,
+                 )
+                 ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                 datasets.append(ds_wrapper)
              elif d.type == "sharegpt":
                  ds_strategy = ShareGPTPromptTokenizingStrategy(
                      ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len
src/axolotl/prompt_tokenizers.py CHANGED
@@ -107,6 +107,67 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
          )


+ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
+     def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
+         raise NotImplementedError
+
+     def tokenize_prompt(self, prompt):
+         instruction, input, output, reflection, corrected = self.parse_instruction_fields(prompt)
+         full_prompt = self._build_full_prompt(instruction, input, output, reflection, corrected)
+         tokenized_full_prompt = self._tokenize(full_prompt)
+         if not self.train_on_inputs:
+             user_prompt = self.prompter.build_prompt(
+                 instruction,
+                 input,
+             )
+             tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
+             user_prompt_len = len(tokenized_user_prompt["input_ids"])
+             # TODO this could be sped up using numpy array slicing
+             tokenized_full_prompt["labels"] = [
+                 -100
+             ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
+
+         return tokenized_full_prompt
+
+     def _build_full_prompt(self, instruction, input, output, reflection, corrected):
+         return self.prompter.build_prompt(
+             instruction,
+             input,
+             output,
+             reflection,
+             corrected,
+         )
+
+     def _tokenize(self, prompt, add_eos_token=True):
+         result = self.tokenizer(
+             prompt,
+             truncation=True,
+             max_length=self.sequence_len,
+             padding=False,
+             return_tensors=None,
+         )
+         if (
+             result["input_ids"][-1] != self.tokenizer.eos_token_id
+             and len(result["input_ids"]) < self.sequence_len
+             and add_eos_token
+         ):
+             result["input_ids"].append(self.tokenizer.eos_token_id)
+             result["attention_mask"].append(1)
+
+         result["labels"] = result["input_ids"].copy()
+         return result
+
+
+ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy):
+     def parse_instruction_fields(self, prompt) -> (str, str, str, str, str):
+         return (
+             prompt["instruction"],
+             prompt["input"] if "input" in prompt else "",
+             prompt["output"],
+             prompt["reflection"],
+             prompt["corrected"],
+         )
+
  class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
      def tokenize_prompt(self, prompt):
          try:
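
A minimal sketch of the label masking performed in ReflectionPromptTokenizingStrategy.tokenize_prompt above: with train_on_inputs set to false (as in the config), the first user_prompt_len label positions are replaced with -100, so the loss is computed only over the assistant output, reflection, and corrected response. The token ids here are made up purely to show the arithmetic:

# Illustration only; real ids come from the tokenizer.
full_prompt_ids = [1, 310, 871, 54, 99, 412, 2]  # prompt + response + EOS
user_prompt_ids = [1, 310, 871]                  # prompt only (no EOS appended)
user_prompt_len = len(user_prompt_ids)

labels = [-100] * user_prompt_len + full_prompt_ids[user_prompt_len:]
assert labels == [-100, -100, -100, 54, 99, 412, 2]  # -100 positions are ignored by the loss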
src/axolotl/prompters.py CHANGED
@@ -35,6 +35,35 @@ class GPTeacherPrompter(AlpacaPrompter):
      ...


+ class ReflectAlpacaPrompter:
+     prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
+     prompt_no_input = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Response:\n"
+     agent_label = "{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}"
+     response_split = "### Response:"
+
+     def build_prompt(
+         self,
+         instruction: str,
+         input: Union[None, str] = None,
+         output: Union[None, str] = None,
+         reflection: Union[None, str] = None,
+         corrected: Union[None, str] = None,
+     ) -> str:
+         # returns the full prompt from instruction and optional input
+         # if a label (=response, =output) is provided, it's also appended.
+         if input:
+             res = self.prompt_input.format(instruction=instruction, input=input)
+         else:
+             res = self.prompt_no_input.format(instruction=instruction)
+         if output and reflection and corrected:
+             label = self.agent_label.format(output=output, reflection=reflection, corrected=corrected)
+             res = f"{res}{label}"
+         return res
+
+     def get_response(self, output: str) -> str:
+         return output.split(self.response_split)[1].strip()
+
+
  class SeparatorStyle(Enum):
      """Different separator style."""

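
A hedged usage sketch of the new ReflectAlpacaPrompter (class and attribute names from the diff above; the example strings are invented). build_prompt returns only the instruction/input scaffold unless output, reflection, and corrected are all provided, in which case the agent_label section is appended for training:

from axolotl.prompters import ReflectAlpacaPrompter

prompter = ReflectAlpacaPrompter()

# Inference-style prompt: ends at "### Response:\n", with nothing appended after it.
inference_prompt = prompter.build_prompt(
    "Summarize the paragraph.",
    input="Some paragraph of source text ...",
)

# Training-style prompt: also contains "### Agent Reflection:" and "### Final Response:" sections.
training_prompt = prompter.build_prompt(
    "Summarize the paragraph.",
    input="Some paragraph of source text ...",
    output="A first-draft summary ...",
    reflection="The draft omits the key point about ...",
    corrected="A revised summary that includes the key point ...",
)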