debisoft committed
Commit: a34933e
Parent: 6927105

1st commit.

app.py ADDED
@@ -0,0 +1,62 @@
+ import openai
+ import os
+ import re
+ from datetime import datetime
+ import gradio as gr
+ import json
+ from dotenv import load_dotenv, find_dotenv
+ _ = load_dotenv(find_dotenv())
+
+ from training.consts import DEFAULT_INPUT_MODEL, SUGGESTED_INPUT_MODELS
+ from training.trainer import load_training_dataset, load_tokenizer
+ from training.generate import generate_response, load_model_tokenizer_for_generate
+
+ gpu_family = "a100"
+
+ model_dir = "model"
+
+ model, tokenizer = load_model_tokenizer_for_generate(model_dir)
+
+ def get_completion(prompt):
+     # NOTE: the OpenAI ChatCompletion call below is disabled: its result was
+     # overwritten by the local model anyway, and "dolly-v0-70m" is not an
+     # OpenAI model, so the call could only fail at runtime.
+     #messages = [{"role": "user", "content": prompt}]
+     #response = openai.ChatCompletion.create(model="dolly-v0-70m", messages=messages,
+     #    temperature=0)  # this is the degree of randomness of the model's output
+
+     # Examples from https://www.databricks.com/blog/2023/03/24/hello-dolly-democratizing-magic-chatgpt-open-models.html
+     instructions = [
+         prompt
+     ]
+
+     # set some additional pipeline args
+     pipeline_kwargs = {'torch_dtype': "auto"}
+     #if gpu_family == "v100":
+     #    pipeline_kwargs['torch_dtype'] = "float16"
+     #elif gpu_family == "a10" or gpu_family == "a100":
+     #    pipeline_kwargs['torch_dtype'] = "bfloat16"
+
+     pipeline_kwargs['max_new_tokens'] = 300
+
+     # Use the model to generate responses for each of the instructions above.
+     for instruction in instructions:
+         response = generate_response(instruction, model=model, tokenizer=tokenizer, **pipeline_kwargs)
+         if response:
+             print(f"Instruction: {instruction}\n\n{response}\n\n-----------\n")
+
+     return response
+
+ def greet(input):
+     prompt = f"""
+ Text: ```{input}```
+ """
+     response = get_completion(prompt)
+     return response
+
+ #iface = gr.Interface(fn=greet, inputs="text", outputs="text")
+ #iface.launch()
+
+ #iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="Text to find entities", lines=2)], outputs=[gr.HighlightedText(label="Text with entities")], title="NER with dslim/bert-base-NER", description="Find entities using the `dslim/bert-base-NER` model under the hood!", allow_flagging="never", examples=["My name is Andrew and I live in California", "My name is Poli and work at HuggingFace"])
+ iface = gr.Interface(fn=greet, inputs=[gr.Textbox(label="Prompt")], outputs="text")
+ iface.launch()
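
Note: app.py routes every Gradio request through greet() and get_completion() into the local fine-tuned checkpoint under model/. A minimal sketch of exercising the same generation path without the UI, assuming the training/ package from this commit is importable and that model weights are actually available under model/ (only config and tokenizer files appear in this diff):

    # Sketch: call the generation path directly, bypassing Gradio.
    from training.generate import generate_response, load_model_tokenizer_for_generate

    model, tokenizer = load_model_tokenizer_for_generate("model")
    # Mirrors the max_new_tokens setting app.py passes via pipeline_kwargs.
    print(generate_response("Explain what a language model is.",
                            model=model, tokenizer=tokenizer,
                            max_new_tokens=300))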
model/config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "_name_or_path": "EleutherAI/pythia-70m",
+   "architectures": [
+     "GPTNeoXForCausalLM"
+   ],
+   "attention_dropout": 0.0,
+   "bos_token_id": 0,
+   "classifier_dropout": 0.1,
+   "eos_token_id": 0,
+   "hidden_act": "gelu",
+   "hidden_dropout": 0.0,
+   "hidden_size": 512,
+   "initializer_range": 0.02,
+   "intermediate_size": 2048,
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 2048,
+   "model_type": "gpt_neox",
+   "num_attention_heads": 8,
+   "num_hidden_layers": 6,
+   "rope_scaling": null,
+   "rotary_emb_base": 10000,
+   "rotary_pct": 0.25,
+   "tie_word_embeddings": false,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.31.0",
+   "use_cache": false,
+   "use_parallel_residual": true,
+   "vocab_size": 50280
+ }
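
Note: this config describes a 6-layer, 512-hidden GPTNeoX network derived from EleutherAI/pythia-70m (on the order of 70M parameters). A quick sanity-check sketch that only needs the committed config file, no weights:

    # Sketch: inspect the architecture described by model/config.json.
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("model")
    print(config.model_type, config.num_hidden_layers, config.hidden_size)  # gpt_neox 6 512
    model = AutoModelForCausalLM.from_config(config)  # randomly initialized, just for counting
    print(sum(p.numel() for p in model.parameters()))  # roughly 70M parameters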
model/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "bos_token_id": 0,
+   "eos_token_id": 0,
+   "transformers_version": "4.31.0",
+   "use_cache": false
+ }
model/special_tokens_map.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "additional_special_tokens": [
+     "### End",
+     "### Instruction:",
+     "### Response:\n"
+   ],
+   "bos_token": "<|endoftext|>",
+   "eos_token": "<|endoftext|>",
+   "pad_token": "<|endoftext|>",
+   "unk_token": "<|endoftext|>"
+ }
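
Note: these additional special tokens are the instruction-format markers from training/consts.py, and get_special_token_id() in training/generate.py relies on each of them encoding to exactly one token ID. A sanity-check sketch against the committed tokenizer files:

    # Sketch: each instruction-format key should map to a single token ID.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("model")
    for key in ["### End", "### Instruction:", "### Response:\n"]:
        print(repr(key), tokenizer.encode(key))  # expect one ID per key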
model/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
model/tokenizer_config.json ADDED
@@ -0,0 +1,9 @@
+ {
+   "add_prefix_space": false,
+   "bos_token": "<|endoftext|>",
+   "clean_up_tokenization_spaces": true,
+   "eos_token": "<|endoftext|>",
+   "model_max_length": 1000000000000000019884624838656,
+   "tokenizer_class": "GPTNeoXTokenizer",
+   "unk_token": "<|endoftext|>"
+ }
model/training_args.bin ADDED
Binary file (5.88 kB).
 
training/__init__.py ADDED
File without changes
training/consts.py ADDED
@@ -0,0 +1,78 @@
+ DEFAULT_INPUT_MODEL = "EleutherAI/pythia-6.9b"
+ SUGGESTED_INPUT_MODELS = [
+     "EleutherAI/pythia-2.8b",
+     "EleutherAI/pythia-6.9b",
+     "EleutherAI/pythia-12b",
+     "EleutherAI/gpt-j-6B",
+     "databricks/dolly-v2-3b",
+     "databricks/dolly-v2-7b",
+     "databricks/dolly-v2-12b"
+ ]
+ DEFAULT_TRAINING_DATASET = "databricks/databricks-dolly-15k"
+ INTRO_BLURB = (
+     "Below is an instruction that describes a task. Write a response that appropriately completes the request."
+ )
+ INSTRUCTION_KEY = "### Instruction:"
+ INPUT_KEY = "Input:"
+ RESPONSE_KEY = "### Response:"
+ END_KEY = "### End"
+ RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
+ DEFAULT_SEED = 42
+
+ # This is a training prompt that does not contain an input string. The instruction by itself has enough information
+ # to respond. For example, the instruction might ask for the year a historic figure was born.
+ PROMPT_NO_INPUT_FORMAT = """{intro}
+
+ {instruction_key}
+ {instruction}
+
+ {response_key}
+ {response}
+
+ {end_key}""".format(
+     intro=INTRO_BLURB,
+     instruction_key=INSTRUCTION_KEY,
+     instruction="{instruction}",
+     response_key=RESPONSE_KEY,
+     response="{response}",
+     end_key=END_KEY,
+ )
+
+ # This is a training prompt that contains an input string that serves as context for the instruction. For example,
+ # the input might be a passage from Wikipedia and the instruction is to extract some information from it.
+ PROMPT_WITH_INPUT_FORMAT = """{intro}
+
+ {instruction_key}
+ {instruction}
+
+ {input_key}
+ {input}
+
+ {response_key}
+ {response}
+
+ {end_key}""".format(
+     intro=INTRO_BLURB,
+     instruction_key=INSTRUCTION_KEY,
+     instruction="{instruction}",
+     input_key=INPUT_KEY,
+     input="{input}",
+     response_key=RESPONSE_KEY,
+     response="{response}",
+     end_key=END_KEY,
+ )
+
+ # This is the prompt that is used for generating responses using an already trained model. It ends with the response
+ # key, where the job of the model is to provide the completion that follows it (i.e. the response itself).
+ PROMPT_FOR_GENERATION_FORMAT = """{intro}
+
+ {instruction_key}
+ {instruction}
+
+ {response_key}
+ """.format(
+     intro=INTRO_BLURB,
+     instruction_key=INSTRUCTION_KEY,
+     instruction="{instruction}",
+     response_key=RESPONSE_KEY,
+ )
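
Note: the .format() calls above fill in the fixed pieces (intro blurb and section keys) while re-inserting literal {instruction}, {input}, and {response} placeholders, so each PROMPT_* template can be formatted a second time per record. A sketch of what the generation-time prompt looks like:

    # Sketch: render the generation prompt for one instruction.
    from training.consts import PROMPT_FOR_GENERATION_FORMAT

    print(PROMPT_FOR_GENERATION_FORMAT.format(instruction="What is the capital of France?"))
    # Below is an instruction that describes a task. Write a response that appropriately completes the request.
    #
    # ### Instruction:
    # What is the capital of France?
    #
    # ### Response: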
training/generate.py ADDED
@@ -0,0 +1,240 @@
+ import logging
+ import re
+ from typing import List, Tuple
+ import torch
+
+ import numpy as np
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     Pipeline,
+     PreTrainedModel,
+     PreTrainedTokenizer,
+ )
+
+ from transformers.utils import is_tf_available
+
+ if is_tf_available():
+     import tensorflow as tf
+
+ from .consts import END_KEY, PROMPT_FOR_GENERATION_FORMAT, RESPONSE_KEY
+
+ logger = logging.getLogger(__name__)
+
+
+ def load_model_tokenizer_for_generate(
+     pretrained_model_name_or_path: str,
+ ) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
+     """Loads the model and tokenizer so that they can be used for generating responses.
+
+     Args:
+         pretrained_model_name_or_path (str): name or path for model
+
+     Returns:
+         Tuple[PreTrainedModel, PreTrainedTokenizer]: model and tokenizer
+     """
+     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path, padding_side="left")
+     model = AutoModelForCausalLM.from_pretrained(
+         pretrained_model_name_or_path, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
+     )
+     return model, tokenizer
+
+
+ def get_special_token_id(tokenizer: PreTrainedTokenizer, key: str) -> int:
+     """Gets the token ID for a given string that has been added to the tokenizer as a special token.
+
+     When training, we configure the tokenizer so that the sequences like "### Instruction:" and "### End" are
+     treated specially and converted to a single, new token. This retrieves the token ID each of these keys maps to.
+
+     Args:
+         tokenizer (PreTrainedTokenizer): the tokenizer
+         key (str): the key to convert to a single token
+
+     Raises:
+         ValueError: if more than one ID was generated
+
+     Returns:
+         int: the token ID for the given key
+     """
+     token_ids = tokenizer.encode(key)
+     if len(token_ids) > 1:
+         raise ValueError(f"Expected only a single token for '{key}' but found {token_ids}")
+     return token_ids[0]
+
+
+ class InstructionTextGenerationPipeline(Pipeline):
+     def __init__(
+         self, *args, do_sample: bool = True, max_new_tokens: int = 256, top_p: float = 0.92, top_k: int = 0, **kwargs
+     ):
+         """Initialize the pipeline
+
+         Args:
+             do_sample (bool, optional): Whether or not to use sampling. Defaults to True.
+             max_new_tokens (int, optional): Max new tokens after the prompt to generate. Defaults to 256.
+             top_p (float, optional): If set to float < 1, only the smallest set of most probable tokens with
+                 probabilities that add up to top_p or higher are kept for generation. Defaults to 0.92.
+             top_k (int, optional): The number of highest probability vocabulary tokens to keep for top-k-filtering.
+                 Defaults to 0.
+         """
+         super().__init__(*args, do_sample=do_sample, max_new_tokens=max_new_tokens, top_p=top_p, top_k=top_k,
+                          **kwargs)
+
+     def _sanitize_parameters(self,
+                              return_full_text: bool = None,
+                              **generate_kwargs):
+         preprocess_params = {}
+
+         # newer versions of the tokenizer configure the response key as a special token. newer versions still may
+         # append a newline to yield a single token. find whatever token is configured for the response key.
+         tokenizer_response_key = next(
+             (token for token in self.tokenizer.additional_special_tokens if token.startswith(RESPONSE_KEY)), None
+         )
+
+         response_key_token_id = None
+         end_key_token_id = None
+         if tokenizer_response_key:
+             try:
+                 response_key_token_id = get_special_token_id(self.tokenizer, tokenizer_response_key)
+                 end_key_token_id = get_special_token_id(self.tokenizer, END_KEY)
+
+                 # Ensure generation stops once it generates "### End"
+                 generate_kwargs["eos_token_id"] = end_key_token_id
+             except ValueError:
+                 pass
+
+         forward_params = generate_kwargs
+         postprocess_params = {
+             "response_key_token_id": response_key_token_id,
+             "end_key_token_id": end_key_token_id
+         }
+
+         if return_full_text is not None:
+             postprocess_params["return_full_text"] = return_full_text
+
+         return preprocess_params, forward_params, postprocess_params
+
+     def preprocess(self, instruction_text, **generate_kwargs):
+         prompt_text = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction_text)
+         inputs = self.tokenizer(
+             prompt_text,
+             return_tensors="pt",
+         )
+         inputs["prompt_text"] = prompt_text
+         inputs["instruction_text"] = instruction_text
+         return inputs
+
+     def _forward(self, model_inputs, **generate_kwargs):
+         input_ids = model_inputs["input_ids"]
+         attention_mask = model_inputs.get("attention_mask", None)
+
+         if input_ids.shape[1] == 0:
+             input_ids = None
+             attention_mask = None
+             in_b = 1
+         else:
+             in_b = input_ids.shape[0]
+
+         generated_sequence = self.model.generate(
+             input_ids=input_ids.to(self.model.device),
+             attention_mask=attention_mask.to(self.model.device),
+             pad_token_id=self.tokenizer.pad_token_id,
+             **generate_kwargs,
+         )
+
+         out_b = generated_sequence.shape[0]
+         if self.framework == "pt":
+             generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
+         elif self.framework == "tf":
+             generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))
+
+         instruction_text = model_inputs.pop("instruction_text")
+         return {"generated_sequence": generated_sequence, "input_ids": input_ids, "instruction_text": instruction_text}
+
+     def postprocess(self, model_outputs, response_key_token_id, end_key_token_id, return_full_text: bool = False):
+
+         generated_sequence = model_outputs["generated_sequence"][0]
+         instruction_text = model_outputs["instruction_text"]
+
+         generated_sequence: List[List[int]] = generated_sequence.numpy().tolist()
+         records = []
+         for sequence in generated_sequence:
+
+             # The response will be set to this variable if we can identify it.
+             decoded = None
+
+             # If we have token IDs for the response and end, then we can find the tokens and only decode between them.
+             if response_key_token_id and end_key_token_id:
+                 # Find where "### Response:" is first found in the generated tokens. Considering this is part of the
+                 # prompt, we should definitely find it. We will return the tokens found after this token.
+                 try:
+                     response_pos = sequence.index(response_key_token_id)
+                 except ValueError:
+                     logger.warn(f"Could not find response key {response_key_token_id} in: {sequence}")
+                     response_pos = None
+
+                 if response_pos:
+                     # Next find where "### End" is located. The model has been trained to end its responses with this
+                     # sequence (or actually, the token ID it maps to, since it is a special token). We may not find
+                     # this token, as the response could be truncated. If we don't find it then just return everything
+                     # to the end. Note that even though we set eos_token_id, we still see this token at the end.
+                     try:
+                         end_pos = sequence.index(end_key_token_id)
+                     except ValueError:
+                         end_pos = None
+
+                     decoded = self.tokenizer.decode(sequence[response_pos + 1 : end_pos]).strip()
+
+             if not decoded:
+                 # Otherwise we'll decode everything and use a regex to find the response and end.
+
+                 fully_decoded = self.tokenizer.decode(sequence)
+
+                 # The response appears after "### Response:". The model has been trained to append "### End" at the
+                 # end.
+                 m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", fully_decoded, flags=re.DOTALL)
+
+                 if m:
+                     decoded = m.group(1).strip()
+                 else:
+                     # The model might not generate the "### End" sequence before reaching the max tokens. In this case,
+                     # return everything after "### Response:".
+                     m = re.search(r"#+\s*Response:\s*(.+)", fully_decoded, flags=re.DOTALL)
+                     if m:
+                         decoded = m.group(1).strip()
+                     else:
+                         logger.warn(f"Failed to find response in:\n{fully_decoded}")
+
+             # If the full text is requested, then append the decoded text to the original instruction.
+             # This technically isn't the full text, as we format the instruction in the prompt the model has been
+             # trained on, but to the client it will appear to be the full text.
+             if return_full_text:
+                 decoded = f"{instruction_text}\n{decoded}"
+
+             rec = {"generated_text": decoded}
+
+             records.append(rec)
+
+         return records
+
+
+ def generate_response(
+     instruction: str,
+     *,
+     model: PreTrainedModel,
+     tokenizer: PreTrainedTokenizer,
+     **kwargs,
+ ) -> str:
+     """Given an instruction, uses the model and tokenizer to generate a response. This formats the instruction in
+     the instruction format that the model was fine-tuned on.
+
+     Args:
+         instruction (str): the instruction to generate a response for
+         model (PreTrainedModel): the model to use
+         tokenizer (PreTrainedTokenizer): the tokenizer to use
+
+     Returns:
+         str: response
+     """
+
+     generation_pipeline = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer, **kwargs)
+     return generation_pipeline(instruction)[0]["generated_text"]
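
Note: generate_response() is a thin convenience wrapper; the custom pipeline can also be used directly, for example to ask for the instruction plus response via return_full_text. A sketch, again assuming model weights are available under model/ (only config and tokenizer files appear in this commit):

    # Sketch: use the pipeline directly instead of generate_response().
    from training.generate import (
        InstructionTextGenerationPipeline,
        load_model_tokenizer_for_generate,
    )

    model, tokenizer = load_model_tokenizer_for_generate("model")
    pipe = InstructionTextGenerationPipeline(model=model, tokenizer=tokenizer, max_new_tokens=128)
    records = pipe("Give three uses for a paperclip.", return_full_text=True)
    print(records[0]["generated_text"])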
training/trainer.py ADDED
@@ -0,0 +1,335 @@
+ # Copyright 2023 Databricks, Inc.
+
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+
+ # http://www.apache.org/licenses/LICENSE-2.0
+
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from functools import partial
+ from pathlib import Path
+ from typing import Any, Dict, List, Tuple, Union
+
+ import click
+ import numpy as np
+ from datasets import Dataset, load_dataset
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     DataCollatorForLanguageModeling,
+     PreTrainedTokenizer,
+     Trainer,
+     TrainingArguments,
+     set_seed,
+ )
+
+ from .consts import (
+     DEFAULT_INPUT_MODEL,
+     DEFAULT_SEED,
+     PROMPT_WITH_INPUT_FORMAT,
+     PROMPT_NO_INPUT_FORMAT,
+     END_KEY,
+     INSTRUCTION_KEY,
+     RESPONSE_KEY_NL,
+     DEFAULT_TRAINING_DATASET,
+ )
+
+ logger = logging.getLogger(__name__)
+ ROOT_PATH = Path(__file__).parent.parent
+
+
+ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
+     def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
+         batch = super().torch_call(examples)
+
+         # The prompt ends with the response key plus a newline. We encode this and then try to find it in the
+         # sequence of tokens. This should just be a single token.
+         response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)
+
+         labels = batch["labels"].clone()
+
+         for i in range(len(examples)):
+
+             response_token_ids_start_idx = None
+             for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
+                 response_token_ids_start_idx = idx
+                 break
+
+             if response_token_ids_start_idx is None:
+                 raise RuntimeError(
+                     f'Could not find response key {response_token_ids} in token IDs {batch["labels"][i]}'
+                 )
+
+             response_token_ids_end_idx = response_token_ids_start_idx + 1
+
+             # Make pytorch loss function ignore all tokens up through the end of the response key
+             labels[i, :response_token_ids_end_idx] = -100
+
+         batch["labels"] = labels
+
+         return batch
+
+
+ def preprocess_batch(batch: Dict[str, List], tokenizer: AutoTokenizer, max_length: int) -> dict:
+     return tokenizer(
+         batch["text"],
+         max_length=max_length,
+         truncation=True,
+     )
+
+
+ def load_training_dataset(path_or_dataset: str = DEFAULT_TRAINING_DATASET) -> Dataset:
+     logger.info(f"Loading dataset from {path_or_dataset}")
+     dataset = load_dataset(path_or_dataset)["train"]
+     logger.info("Found %d rows", dataset.num_rows)
+
+     def _add_text(rec):
+         instruction = rec["instruction"]
+         response = rec["response"]
+         context = rec.get("context")
+
+         if not instruction:
+             raise ValueError(f"Expected an instruction in: {rec}")
+
+         if not response:
+             raise ValueError(f"Expected a response in: {rec}")
+
+         # For some instructions there is an input that goes along with the instruction, providing context for the
+         # instruction. For example, the input might be a passage from Wikipedia and the instruction says to extract
+         # some piece of information from it. The response is that information to extract. In other cases there is
+         # no input. For example, the instruction might be open QA such as asking what year some historic figure was
+         # born.
+         if context:
+             rec["text"] = PROMPT_WITH_INPUT_FORMAT.format(instruction=instruction, response=response, input=context)
+         else:
+             rec["text"] = PROMPT_NO_INPUT_FORMAT.format(instruction=instruction, response=response)
+         return rec
+
+     dataset = dataset.map(_add_text)
+
+     return dataset
+
+
+ def load_tokenizer(pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL) -> PreTrainedTokenizer:
+     logger.info(f"Loading tokenizer for {pretrained_model_name_or_path}")
+     tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.add_special_tokens({"additional_special_tokens": [END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL]})
+     return tokenizer
+
+
+ def load_model(
+     pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL, *, gradient_checkpointing: bool = False
+ ) -> AutoModelForCausalLM:
+     logger.info(f"Loading model for {pretrained_model_name_or_path}")
+     model = AutoModelForCausalLM.from_pretrained(
+         pretrained_model_name_or_path, trust_remote_code=True, use_cache=False if gradient_checkpointing else True
+     )
+     return model
+
+
+ def get_model_tokenizer(
+     pretrained_model_name_or_path: str = DEFAULT_INPUT_MODEL, *, gradient_checkpointing: bool = False
+ ) -> Tuple[AutoModelForCausalLM, PreTrainedTokenizer]:
+     tokenizer = load_tokenizer(pretrained_model_name_or_path)
+     model = load_model(pretrained_model_name_or_path, gradient_checkpointing=gradient_checkpointing)
+     model.resize_token_embeddings(len(tokenizer))
+
+     return model, tokenizer
+
+
+ def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed=DEFAULT_SEED, training_dataset: str = DEFAULT_TRAINING_DATASET) -> Dataset:
+     """Loads the training dataset and tokenizes it so it is ready for training.
+
+     Args:
+         tokenizer (AutoTokenizer): Tokenizer tied to the model.
+         max_length (int): Maximum number of tokens to emit from tokenizer.
+
+     Returns:
+         Dataset: HuggingFace dataset
+     """
+
+     dataset = load_training_dataset(training_dataset)
+
+     logger.info("Preprocessing dataset")
+     _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
+     dataset = dataset.map(
+         _preprocessing_function,
+         batched=True,
+         remove_columns=["instruction", "context", "response", "text", "category"],
+     )
+
+     # Make sure we don't have any truncated records, as this would mean the end keyword is missing.
+     logger.info("Processed dataset has %d rows", dataset.num_rows)
+     dataset = dataset.filter(lambda rec: len(rec["input_ids"]) < max_length)
+     logger.info("Processed dataset has %d rows after filtering for truncated records", dataset.num_rows)
+
+     logger.info("Shuffling dataset")
+     dataset = dataset.shuffle(seed=seed)
+
+     logger.info("Done preprocessing")
+
+     return dataset
+
+
+ def train(
+     *,
+     input_model: str,
+     local_output_dir: str,
+     dbfs_output_dir: str,
+     epochs: int,
+     per_device_train_batch_size: int,
+     per_device_eval_batch_size: int,
+     lr: float,
+     seed: int,
+     deepspeed: str,
+     gradient_checkpointing: bool,
+     local_rank: str,
+     bf16: bool,
+     logging_steps: int,
+     save_steps: int,
+     eval_steps: int,
+     test_size: Union[float, int],
+     save_total_limit: int,
+     warmup_steps: int,
+     training_dataset: str = DEFAULT_TRAINING_DATASET,
+ ):
+     set_seed(seed)
+
+     model, tokenizer = get_model_tokenizer(
+         pretrained_model_name_or_path=input_model, gradient_checkpointing=gradient_checkpointing
+     )
+
+     # Use the same max length that the model supports. Fall back to 1024 if the setting can't be found.
+     # The configuration for the length can be stored under different names depending on the model. Here we attempt
+     # a few possible names we've encountered.
+     conf = model.config
+     max_length = None
+     for length_setting in ["n_positions", "max_position_embeddings", "seq_length"]:
+         max_length = getattr(model.config, length_setting, None)
+         if max_length:
+             logger.info(f"Found max length: {max_length}")
+             break
+     if not max_length:
+         max_length = 1024
+         logger.info(f"Using default max length: {max_length}")
+
+     processed_dataset = preprocess_dataset(tokenizer=tokenizer, max_length=max_length, seed=seed, training_dataset=training_dataset)
+
+     split_dataset = processed_dataset.train_test_split(test_size=test_size, seed=seed)
+
+     logger.info("Train data size: %d", split_dataset["train"].num_rows)
+     logger.info("Test data size: %d", split_dataset["test"].num_rows)
+
+     data_collator = DataCollatorForCompletionOnlyLM(
+         tokenizer=tokenizer, mlm=False, return_tensors="pt", pad_to_multiple_of=8
+     )
+
+     # enable fp16 if not bf16
+     fp16 = not bf16
+
+     if not dbfs_output_dir:
+         logger.warn("Will NOT save to DBFS")
+
+     training_args = TrainingArguments(
+         output_dir=local_output_dir,
+         per_device_train_batch_size=per_device_train_batch_size,
+         per_device_eval_batch_size=per_device_eval_batch_size,
+         fp16=fp16,
+         bf16=bf16,
+         learning_rate=lr,
+         num_train_epochs=epochs,
+         deepspeed=deepspeed,
+         gradient_checkpointing=gradient_checkpointing,
+         logging_dir=f"{local_output_dir}/runs",
+         logging_strategy="steps",
+         logging_steps=logging_steps,
+         evaluation_strategy="steps",
+         eval_steps=eval_steps,
+         save_strategy="steps",
+         save_steps=save_steps,
+         save_total_limit=save_total_limit,
+         load_best_model_at_end=False,
+         report_to="tensorboard",
+         disable_tqdm=True,
+         remove_unused_columns=False,
+         local_rank=local_rank,
+         warmup_steps=warmup_steps,
+     )
+
+     logger.info("Instantiating Trainer")
+
+     trainer = Trainer(
+         model=model,
+         tokenizer=tokenizer,
+         args=training_args,
+         train_dataset=split_dataset["train"],
+         eval_dataset=split_dataset["test"],
+         data_collator=data_collator,
+     )
+
+     logger.info("Training")
+     trainer.train()
+
+     logger.info(f"Saving Model to {local_output_dir}")
+     trainer.save_model(output_dir=local_output_dir)
+
+     if dbfs_output_dir:
+         logger.info(f"Saving Model to {dbfs_output_dir}")
+         trainer.save_model(output_dir=dbfs_output_dir)
+
+     logger.info("Done.")
+
+
+ @click.command()
+ @click.option("--input-model", type=str, help="Input model to fine tune", default=DEFAULT_INPUT_MODEL)
+ @click.option("--local-output-dir", type=str, help="Write directly to this local path", required=True)
+ @click.option("--dbfs-output-dir", type=str, help="Sync data to this path on DBFS")
+ @click.option("--epochs", type=int, default=3, help="Number of epochs to train for.")
+ @click.option("--per-device-train-batch-size", type=int, default=8, help="Batch size to use for training.")
+ @click.option("--per-device-eval-batch-size", type=int, default=8, help="Batch size to use for evaluation.")
+ @click.option(
+     "--test-size", type=int, default=1000, help="Number of test records for evaluation, or ratio of test records."
+ )
+ @click.option("--warmup-steps", type=int, default=None, help="Number of steps to warm up to learning rate")
+ @click.option("--logging-steps", type=int, default=10, help="How often to log")
+ @click.option("--eval-steps", type=int, default=50, help="How often to run evaluation on test records")
+ @click.option("--save-steps", type=int, default=400, help="How often to checkpoint the model")
+ @click.option("--save-total-limit", type=int, default=10, help="Maximum number of checkpoints to keep on disk")
+ @click.option("--lr", type=float, default=1e-5, help="Learning rate to use for training.")
+ @click.option("--seed", type=int, default=DEFAULT_SEED, help="Seed to use for training.")
+ @click.option("--deepspeed", type=str, default=None, help="Path to deepspeed config file.")
+ @click.option("--training-dataset", type=str, default=DEFAULT_TRAINING_DATASET, help="Path to dataset for training")
+ @click.option(
+     "--gradient-checkpointing/--no-gradient-checkpointing",
+     is_flag=True,
+     default=True,
+     help="Use gradient checkpointing?",
+ )
+ @click.option(
+     "--local_rank",
+     type=str,
+     default=True,
+     help="Provided by deepspeed to identify which instance this process is when performing multi-GPU training.",
+ )
+ @click.option("--bf16", type=bool, default=None, help="Whether to use bf16 (preferred on A100's).")
+ def main(**kwargs):
+     train(**kwargs)
+
+
+ if __name__ == "__main__":
+     logging.basicConfig(
+         format="%(asctime)s %(levelname)s [%(name)s] %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
+     )
+     try:
+         main()
+     except Exception:
+         logger.exception("main failed")
+         raise
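
Note: trainer.py is normally driven through the click options above; the same entry point can also be called from Python. The values below are illustrative only (the exact settings used to produce the model/ checkpoint in this commit are not recorded here), shown as a sketch:

    # Sketch: drive fine-tuning programmatically; all values are illustrative assumptions.
    from training.trainer import train

    train(
        input_model="EleutherAI/pythia-70m",  # assumption: matches model/config.json, not necessarily what was used
        local_output_dir="model",
        dbfs_output_dir=None,
        epochs=1,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        lr=1e-5,
        seed=42,
        deepspeed=None,
        gradient_checkpointing=False,
        local_rank=-1,                        # -1 = no distributed launcher
        bf16=True,                            # preferred on A100s per the --bf16 help text
        logging_steps=10,
        save_steps=400,
        eval_steps=50,
        test_size=1000,
        save_total_limit=10,
        warmup_steps=0,
    )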