winglian commited on
Commit
4ac9e25
·
1 Parent(s): 328c3bc

new prompters, misc fixes for output dir missing using fsdp, and changing max seq len

Browse files
scripts/finetune.py CHANGED
@@ -279,6 +279,9 @@ def train(
279
  logging.info(
280
  f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
281
  )
 
 
 
282
  trainer.train(resume_from_checkpoint=resume_from_checkpoint)
283
 
284
  logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
 
279
  logging.info(
280
  f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}"
281
  )
282
+
283
+ if not Path(cfg.output_dir).is_dir():
284
+ os.makedirs(cfg.output_dir, exist_ok=True)
285
  trainer.train(resume_from_checkpoint=resume_from_checkpoint)
286
 
287
  logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
src/axolotl/prompt_strategies/alpaca_chat.py CHANGED
@@ -18,6 +18,15 @@ def load(tokenizer, cfg):
18
  )
19
 
20
 
 
 
 
 
 
 
 
 
 
21
  class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
22
  """
23
  Tokenizing strategy for AlpacaQA
@@ -31,6 +40,28 @@ class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
31
  )
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def load_qa(tokenizer, cfg):
35
  return AlpacaQAPromptTokenizingStrategy(
36
  AlpacaPrompter(PromptStyle.CHAT.value),
@@ -38,3 +69,12 @@ def load_qa(tokenizer, cfg):
38
  cfg.train_on_inputs,
39
  cfg.sequence_len,
40
  )
 
 
 
 
 
 
 
 
 
 
18
  )
19
 
20
 
21
+ class AlpacaConcisePrompter(AlpacaPrompter):
22
+ """
23
+ Alpaca Prompter extending the system prompt to ask for concise answers
24
+ """
25
+
26
+ system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that concisely and appropriately completes the request.\n\n"
27
+ system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately and concisely completes the request.\n\n"
28
+
29
+
30
  class AlpacaQAPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
31
  """
32
  Tokenizing strategy for AlpacaQA
 
40
  )
41
 
42
 
43
+ class CamelAIPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
44
+ """
45
+ Tokenizing strategy for CamelAI datasets
46
+ """
47
+
48
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
49
+ return (
50
+ prompt["message_1"],
51
+ "",
52
+ prompt["message_1"],
53
+ )
54
+
55
+
56
+ def load_concise(tokenizer, cfg):
57
+ return AlpacaPromptTokenizingStrategy(
58
+ AlpacaConcisePrompter(PromptStyle.CHAT.value),
59
+ tokenizer,
60
+ cfg.train_on_inputs,
61
+ cfg.sequence_len,
62
+ )
63
+
64
+
65
  def load_qa(tokenizer, cfg):
66
  return AlpacaQAPromptTokenizingStrategy(
67
  AlpacaPrompter(PromptStyle.CHAT.value),
 
69
  cfg.train_on_inputs,
70
  cfg.sequence_len,
71
  )
72
+
73
+
74
+ def load_camel_ai(tokenizer, cfg):
75
+ return CamelAIPromptTokenizingStrategy(
76
+ AlpacaPrompter(PromptStyle.CHAT.value),
77
+ tokenizer,
78
+ cfg.train_on_inputs,
79
+ cfg.sequence_len,
80
+ )
src/axolotl/prompt_strategies/context_qa.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Module containing the classes for Context QA Prompt Tokenization Strategies"""
2
+ from typing import Tuple
3
+
4
+ from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
5
+ from axolotl.prompters import AlpacaPrompter, PromptStyle
6
+
7
+
8
+ # article, unanswerable_question, question, answer
9
+ def load_404(tokenizer, cfg):
10
+ return AlpacaMissingInfoContextPromptTokenizingStrategy(
11
+ AlpacaContextPrompter(PromptStyle.CHAT.value),
12
+ tokenizer,
13
+ cfg.train_on_inputs,
14
+ cfg.sequence_len,
15
+ )
16
+
17
+
18
+ def load(tokenizer, cfg):
19
+ return AlpacaContextPromptTokenizingStrategy(
20
+ AlpacaContextPrompter(PromptStyle.CHAT.value),
21
+ tokenizer,
22
+ cfg.train_on_inputs,
23
+ cfg.sequence_len,
24
+ )
25
+
26
+
27
+ class AlpacaContextPrompter(AlpacaPrompter):
28
+ """
29
+ Customized system prompted for concise QA
30
+ """
31
+
32
+ system_prompt = (
33
+ "Use the following contextual information to concisely answer the question.\n"
34
+ )
35
+ system_no_input_prompt = (
36
+ "Use the following contextual information to concisely answer the question.\n"
37
+ )
38
+
39
+
40
+ class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
41
+ """
42
+ Tokenization Strategy to combine in-context article with a question and answer
43
+ """
44
+
45
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
46
+ return (
47
+ prompt["article"] + "\n===\n" + prompt["question"],
48
+ "",
49
+ prompt["answer"],
50
+ )
51
+
52
+
53
+ class AlpacaMissingInfoContextPromptTokenizingStrategy(
54
+ InstructionPromptTokenizingStrategy
55
+ ):
56
+ """
57
+ Tokenization Strategy to combine in-context article with a question that can't be answered
58
+ from the context and a default response to that effect
59
+ """
60
+
61
+ def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
62
+ return (
63
+ prompt["article"] + "\n===\n" + prompt["unanswerable_question"],
64
+ "",
65
+ "The context provided does not contain any information about your inquiry. "
66
+ "Therefore, I'm unable to answer your question based on the given context.",
67
+ )
src/axolotl/utils/models.py CHANGED
@@ -234,6 +234,10 @@ def load_model(
234
  base_model,
235
  trust_remote_code=cfg.trust_remote_code or False,
236
  )
 
 
 
 
237
  model = AutoModelForCausalLM.from_pretrained(
238
  base_model,
239
  config=config,
 
234
  base_model,
235
  trust_remote_code=cfg.trust_remote_code or False,
236
  )
237
+ # Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
238
+ # when training starts
239
+ if config.max_seq_len and cfg.sequence_len > config.max_seq_len:
240
+ config.max_seq_len = cfg.sequence_len
241
  model = AutoModelForCausalLM.from_pretrained(
242
  base_model,
243
  config=config,