winglian commited on
Commit
ce24f5e
1 Parent(s): e9da4b9

WIP for axolotl trainer

Browse files
.editorconfig ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ root = true
2
+
3
+ [*]
4
+ end_of_line = lf
5
+ insert_final_newline = true
6
+ trim_trailing_whitespace = true
7
+
8
+ [*.py]
9
+ indent_style = space
10
+ indent_size = 4
11
+
12
+ [**.yml]
13
+ indent_style = space
14
+ indent_size = 2
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ **/axolotl.egg-info
2
+ **/__pycache__
3
+ .idea
README.md CHANGED
@@ -1,6 +1,13 @@
1
  # Axolotl
2
 
3
- ### You know you're going to axolotl questions
4
 
5
 
 
6
 
 
 
 
 
 
 
 
1
  # Axolotl
2
 
3
+ #### You know you're going to axolotl questions
4
 
5
 
6
+ ### Converting JSON data files to JSONL
7
 
8
+ ```shell
9
+ python3 ./scripts/alpaca_json_to_jsonl.py --input data/alpaca_data_gpt4.json > data/alpaca_data_gpt4.jsonl
10
+ python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/vicuna_cleaned.json > data/vicuna_cleaned.jsonl
11
+ python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/roleplay-similarity_0.6-instruct-dataset.json > data/roleplay-similarity_0.6-instruct-dataset.jsonl
12
+ python3 ./scripts/alpaca_json_to_jsonl.py --input data/raw/gpt4-instruct-similarity-0.6-dataset.json > data/gpt4-instruct-similarity-0.6-dataset.jsonl
13
+ ```
configs/pythia_1_2B_alpaca.yml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ base_model: EleutherAI/pythia-1.4b-deduped
2
+ model_type: GPTNeoXForCausalLM
3
+ tokenizer_type: AutoTokenizer
4
+ load_in_8bit: true
5
+ datasets:
6
+ - path: ./data/alpaca_data_gpt4.jsonl
7
+ type: alpaca
8
+ - path: ./data/vicuna_cleaned.jsonl
9
+ type: sharegpt
10
+ - path: ./data/gpt4-instruct-similarity-0.6-dataset.jsonl
11
+ type: gpteacher
12
+ - path: ./data/roleplay-similarity_0.6-instruct-dataset.jsonl
13
+ type: gpteacher
14
+ val_set_size: 0.05
15
+ adapter: lora
16
+ sequence_len: 2048
17
+ lora_r: 16
18
+ lora_alpha: 32
19
+ lora_dropout: 0.05
20
+ lora_target_modules:
21
+ - q_proj
22
+ - v_proj
23
+ wandb_project:
24
+ wandb_watch:
25
+ wandb:run_name:
26
+ wandb_log_model: checkpoint
27
+ output_dir: ./lora-alpaca
28
+ batch_size: 128
29
+ micro_batch_size: 8
30
+ num_epochs: 5
31
+ learning_rate: 0.0003
32
+ train_on_inputs: false
33
+ bf16: True
34
+ fp16: True
35
+ resume_from_checkpoint:
36
+ local_rank:
37
+ deepspeed:
data/README.md ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ ```shell
4
+ curl https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_gpt4.json -o raw/alpaca_data_gpt4.json
5
+ curl https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -L -o raw/vicuna_cleaned.json
6
+ curl https://github.com/teknium1/GPTeacher/blob/main/Instruct/gpt4-instruct-similarity-0.6-dataset.json?raw=true -L -o raw/gpt4-instruct-similarity-0.6-dataset.json
7
+ curl https://github.com/teknium1/GPTeacher/blob/main/Roleplay/roleplay-similarity_0.6-instruct-dataset.json?raw=true -L -o raw/roleplay-similarity_0.6-instruct-dataset.json
8
+ ```
data/raw/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ **
pyproject.toml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools", "wheel"]
3
+ build-backend = "setuptools.build_meta"
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ git+https://github.com/huggingface/transformers.git
2
+ git+https://github.com/huggingface/peft.git
3
+ attrdict
4
+ fire
5
+ PyYAML==6.0
6
+ black
scripts/alpaca_json_to_jsonl.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ import fire
6
+ from typing import Optional
7
+
8
+ # add src to the pythonpath so we don't need to pip install this
9
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
10
+ src_dir = os.path.join(project_root, 'src')
11
+ sys.path.insert(0, src_dir)
12
+
13
+ from axolotl.convert import *
14
+
15
+ def main(
16
+ input: Path,
17
+ output: Optional[Path] = None,
18
+ to_stdout: Optional[bool] = False,
19
+ ):
20
+ file_reader = FileReader()
21
+ if to_stdout or output is None:
22
+ writer = StdoutWriter()
23
+ else:
24
+ writer = FileWriter(output)
25
+ json_parser = JsonParser()
26
+ jsonl_serializer = JsonlSerializer()
27
+
28
+ converter = JsonToJsonlConverter(
29
+ file_reader, writer, json_parser, jsonl_serializer
30
+ )
31
+
32
+ converter.convert(input, output)
33
+
34
+
35
+ if __name__ == "__main__":
36
+ fire.Fire(main)
scripts/finetune.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from pathlib import Path
4
+
5
+ import fire
6
+ import torch
7
+ import transformers
8
+ import yaml
9
+ from attrdict import AttrDict
10
+ from datasets import load_dataset, IterableDataset
11
+ from peft import (
12
+ LoraConfig,
13
+ get_peft_model,
14
+ prepare_model_for_int8_training,
15
+ )
16
+ from transformers import AutoModelForCausalLM, AutoTokenizer
17
+
18
+ # add src to the pythonpath so we don't need to pip install this
19
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
20
+ src_dir = os.path.join(project_root, 'src')
21
+ sys.path.insert(0, src_dir)
22
+
23
+ from axolotl.datasets import TokenizedPromptDataset
24
+ from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy, ShareGPTPromptTokenizingStrategy, \
25
+ LLAMA_DEFAULT_PAD_TOKEN, GPTeacherPromptTokenizingStrategy
26
+ from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter
27
+
28
+ def setup_wandb_env_vars(cfg):
29
+ if len(cfg.wandb_project) > 0:
30
+ os.environ["WANDB_PROJECT"] = cfg.wandb_project
31
+ cfg.use_wandb = True
32
+ if len(cfg.wandb_watch) > 0:
33
+ os.environ["WANDB_WATCH"] = cfg.wandb_watch
34
+ if len(cfg.wandb_log_model) > 0:
35
+ os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
36
+
37
+
38
+ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
39
+ if adapter != "lora":
40
+ raise NotImplementedError(f"{adapter} peft adapter not available")
41
+ try:
42
+ model = getattr(transformers, model_type).from_pretrained(
43
+ base_model,
44
+ load_in_8bit=cfg.load_in_8bit,
45
+ torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
46
+ device_map=cfg.device_map,
47
+ )
48
+ except:
49
+ model = AutoModelForCausalLM.from_pretrained(
50
+ base_model,
51
+ load_in_8bit=cfg.load_in_8bit,
52
+ torch_dtype=torch.float16 if cfg.load_in_8bit else torch.float32,
53
+ device_map=cfg.device_map,
54
+ )
55
+
56
+ try:
57
+ tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model)
58
+ except:
59
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
60
+
61
+ if tokenizer.__class__.__name__ == "LlamaTokenizer":
62
+ tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN
63
+
64
+ if cfg.load_in_8bit:
65
+ model = prepare_model_for_int8_training(model)
66
+
67
+ lora_config = LoraConfig(
68
+ r=cfg.lora_r,
69
+ lora_alpha=cfg.lora_alpha,
70
+ target_modules=cfg.lora_target_modules,
71
+ lora_dropout=cfg.lora_dropout,
72
+ bias="none",
73
+ task_type="CAUSAL_LM",
74
+ )
75
+ model = get_peft_model(model, lora_config)
76
+ if cfg.ddp:
77
+ model.to(f"cuda:{cfg.local_rank}")
78
+
79
+ # TODO resume_from_checkpoint handling
80
+
81
+ model.print_trainable_parameters()
82
+ return model, tokenizer
83
+
84
+
85
+ def train(
86
+ config: Path = Path('configs/pythia_1_2B_alpaca.yml'),
87
+ **kwargs,
88
+ ):
89
+ # load the config from the yaml file
90
+ with open(config, 'r') as f:
91
+ cfg: AttrDict = AttrDict(yaml.load(f))
92
+ # if there are any options passed in the cli, if it is something that seems valid from the yaml,
93
+ # then overwrite the value
94
+ for k, v in enumerate(kwargs):
95
+ if k in cfg:
96
+ cfg.k = v
97
+
98
+ # setup some derived config / hyperparams
99
+ cfg.gradient_accumulation_steps = cfg.batch_size // cfg.micro_batch_size
100
+ cfg.device_map = "auto"
101
+ cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
102
+ cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
103
+ cfg.ddp = cfg.world_size != 1
104
+ if cfg.ddp:
105
+ cfg.device_map = {"": int(os.environ.get("LOCAL_RANK", 0))}
106
+ cfg.gradient_accumulation_steps = cfg.gradient_accumulation_steps // cfg.world_size
107
+ setup_wandb_env_vars(cfg)
108
+
109
+ # Load the model and tokenizer
110
+ model, tokenizer = load_model(cfg.base_model, cfg.model_type, cfg.tokenizer_type, cfg, adapter=cfg.adapter)
111
+ datasets = []
112
+ for d in cfg.datasets:
113
+ ds: IterableDataset = load_dataset("json", data_files=d.path, streaming=True, num_proc=4, split=None)
114
+ if d.type == "alpaca":
115
+ ds_strategy = AlpacaPromptTokenizingStrategy(AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
116
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
117
+ datasets.append(ds_wrapper)
118
+ elif d.type == "gpteacher":
119
+ ds_strategy = GPTeacherPromptTokenizingStrategy(GPTeacherPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
120
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
121
+ datasets.append(ds_wrapper)
122
+ elif d.type == "sharegpt":
123
+ ds_strategy = ShareGPTPromptTokenizingStrategy(ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len)
124
+ ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
125
+ datasets.append(ds_wrapper)
126
+
127
+
128
+ if __name__ == "__main__":
129
+ fire.Fire(train)
setup.cfg ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [metadata]
2
+ name = axolotl
3
+ version = 0.1.0
4
+ description = You know you're going to axolotl questions
5
+ author = Wing Lian
6
+ author_email = wing.lian@gmail.com
7
+ license = MIT
8
+
9
+ [options]
10
+ package_dir =
11
+ =src
12
+ packages = find:
13
+ install_requires =
14
+ transformers @ git+https://github.com/huggingface/transformers.git@main
15
+ peft @ git+https://github.com/huggingface/peft.git@main
16
+ attrdict
17
+ fire
18
+ PyYAML == 6.0
19
+ black
20
+
21
+ [options.packages.find]
22
+ where = src
23
+
src/axolotl/__init__.py ADDED
File without changes
src/axolotl/convert.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import sys
3
+
4
+
5
+ class FileReader:
6
+ def read(self, file_path):
7
+ with open(file_path, "r") as file:
8
+ return file.read()
9
+
10
+
11
+ class FileWriter:
12
+ def __init__(self, file_path):
13
+ self.file_path = file_path
14
+
15
+ def write(self, content):
16
+ with open(self.file_path, "w") as file:
17
+ file.write(content)
18
+
19
+
20
+ class StdoutWriter:
21
+ def write(self, content):
22
+ sys.stdout.write(content)
23
+ sys.stdout.write("\n")
24
+
25
+
26
+ class JsonParser:
27
+ def parse(self, content):
28
+ return json.loads(content)
29
+
30
+
31
+ class JsonlSerializer:
32
+ def serialize(self, data):
33
+ lines = [json.dumps(item) for item in data]
34
+ return "\n".join(lines)
35
+
36
+
37
+ class JsonToJsonlConverter:
38
+ def __init__(self, file_reader, file_writer, json_parser, jsonl_serializer):
39
+ self.file_reader = file_reader
40
+ self.file_writer = file_writer
41
+ self.json_parser = json_parser
42
+ self.jsonl_serializer = jsonl_serializer
43
+
44
+ def convert(self, input_file_path, output_file_path):
45
+ content = self.file_reader.read(input_file_path)
46
+ data = self.json_parser.parse(content)
47
+ jsonl_content = self.jsonl_serializer.serialize(data)
48
+ self.file_writer.write(jsonl_content)
49
+
50
+
src/axolotl/datasets.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ import torch
4
+ from datasets import IterableDataset
5
+ from .prompt_tokenizers import PromptTokenizingStrategy
6
+
7
+
8
+ # We want this to be a wrapper for an existing dataset that we have loaded
9
+ # lets use the concept of middlewares to wrap each dataset, for example
10
+ # ConstantLengthDataset(ShuffledDataset([TokenizedPromptDataset(alpaca_dataset)]))
11
+ # let's check to ensure we don't truncate an item in the middle, we'll use
12
+ # the collators later on to pad the datasets
13
+
14
+
15
+ class TokenizedPromptDataset(IterableDataset):
16
+ def __init__(
17
+ self,
18
+ prompt_tokenizer: PromptTokenizingStrategy,
19
+ dataset: IterableDataset,
20
+ ):
21
+ self.prompt_tokenizer = prompt_tokenizer
22
+ self.dataset = dataset
23
+
24
+ def __iter__(self):
25
+ iterator = iter(self.dataset)
26
+ yield self.prompt_tokenizer.tokenize_prompt(next(iterator))
27
+
28
+
29
+ class ConstantLengthDataset(IterableDataset):
30
+ """
31
+ Iterable dataset that returns constant length chunks of tokens from stream of text files.
32
+ Args:
33
+ tokenizer (Tokenizer): The processor used for proccessing the data.
34
+ dataset (dataset.Dataset): Dataset with text files.
35
+ infinite (bool): If True the iterator is reset after dataset reaches end else stops.
36
+ seq_length (int): Length of token sequences to return.
37
+ chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ tokenizer,
43
+ datasets,
44
+ infinite=False,
45
+ seq_length=2048,
46
+ num_of_sequences=1024,
47
+ chars_per_token=3.6,
48
+ ):
49
+ self.tokenizer = tokenizer
50
+ self.concat_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id else args.eos_token_id
51
+ self.datasets: List[IterableDataset] = datasets
52
+ self.seq_length = seq_length
53
+ self.infinite = infinite
54
+ self.current_size = 0
55
+ self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
56
+
57
+ def __iter__(self):
58
+ iterator = iter(self.datasets)
59
+ more_examples = True
60
+ while more_examples:
61
+ buffer, buffer_len = [], 0
62
+ while True:
63
+ if buffer_len >= self.max_buffer_size:
64
+ break
65
+ try:
66
+ buffer.append(next(iterator))
67
+ buffer_len += len(buffer[-1])
68
+ except StopIteration:
69
+ if self.infinite:
70
+ iterator = iter(self.datasets)
71
+ else:
72
+ more_examples = False
73
+ break
74
+ tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
75
+ all_token_ids = []
76
+ for tokenized_input in tokenized_inputs:
77
+ all_token_ids.extend(tokenized_input + [self.concat_token_id])
78
+ for i in range(0, len(all_token_ids), self.seq_length):
79
+ input_ids = all_token_ids[i : i + self.seq_length]
80
+ if len(input_ids) == self.seq_length:
81
+ self.current_size += 1
82
+ yield {
83
+ "input_ids": torch.LongTensor(input_ids),
84
+ "labels": torch.LongTensor(input_ids),
85
+ "attention_masks": torch.LongTensor(input_ids),
86
+ }
src/axolotl/prompt_tokenizers.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+
3
+ from transformers import PreTrainedTokenizer
4
+
5
+ IGNORE_INDEX = -100
6
+ LLAMA_DEFAULT_PAD_TOKEN = "[PAD]"
7
+ LLAMA_DEFAULT_EOS_TOKEN = "</s>"
8
+ LLAMA_DEFAULT_BOS_TOKEN = "<s>"
9
+ LLAMA_DEFAULT_UNK_TOKEN = "<unk>"
10
+
11
+
12
+ class PromptTokenizingStrategy(abc.ABC):
13
+ def __init__(
14
+ self,
15
+ prompter,
16
+ tokenizer,
17
+ train_on_inputs: bool = False,
18
+ sequence_len: int = 2048,
19
+ ):
20
+ self.prompter = prompter
21
+ self.tokenizer: PreTrainedTokenizer = tokenizer
22
+ self.train_on_inputs = train_on_inputs
23
+ self.sequence_len = sequence_len
24
+
25
+ @abc.abstractmethod
26
+ def tokenize_prompt(self, prompt):
27
+ pass
28
+
29
+
30
+ class AlpacaPromptTokenizingStrategy(PromptTokenizingStrategy):
31
+ def tokenize_prompt(self, prompt):
32
+ full_prompt = self._tokenize_full_prompt(prompt)
33
+ tokenized_full_prompt = self._tokenize(full_prompt)
34
+ if not self.train_on_inputs:
35
+ user_prompt = self.prompter.generate_prompt(
36
+ prompt["instruction"], prompt["input"]
37
+ )
38
+ tokenized_user_prompt = self._tokenize(user_prompt, add_eos_token=False)
39
+ user_prompt_len = len(tokenized_user_prompt["input_ids"])
40
+ # TODO this could be sped up using numpy array slicing
41
+ tokenized_full_prompt["labels"] = [-100] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
42
+
43
+ return tokenized_full_prompt
44
+
45
+ def _tokenize_full_prompt(self, prompt):
46
+ return self.prompter.generate_prompt(
47
+ prompt["instruction"],
48
+ prompt["input"],
49
+ prompt["output"],
50
+ )
51
+
52
+ def _tokenize(self, prompt, add_eos_token=True):
53
+ result = self.tokenizer(
54
+ prompt,
55
+ truncation=True,
56
+ max_length=self.sequence_len,
57
+ padding=False,
58
+ return_tensors=None,
59
+ )
60
+ if (
61
+ result["input_ids"][-1] != self.tokenizer.eos_token_id
62
+ and len(result["input_ids"]) < self.sequence_len
63
+ and add_eos_token
64
+ ):
65
+ result["input_ids"].append(self.tokenizer.eos_token_id)
66
+ result["attention_mask"].append(1)
67
+
68
+ result["labels"] = result["input_ids"].copy()
69
+ return result
70
+
71
+
72
+ class GPTeacherPromptTokenizingStrategy(AlpacaPromptTokenizingStrategy):
73
+ def _tokenize_full_prompt(self, prompt):
74
+ return self.prompter.generate_prompt(
75
+ prompt["instruction"],
76
+ prompt["input"],
77
+ prompt["response"],
78
+ )
79
+
80
+
81
+ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
82
+ def tokenize_prompt(self, prompt):
83
+ pass
src/axolotl/prompters.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ class AlpacaPrompter:
2
+ pass
3
+
4
+
5
+ class ShareGPTPrompter:
6
+ pass
7
+
8
+
9
+ class GPTeacherPrompter:
10
+ pass