winglian committed on
Commit 4a17a4c
1 Parent(s): 097d367

fix dataset handling, support galactica

configs/galactica_1_3B.yml ADDED
@@ -0,0 +1,41 @@
+ base_model: facebook/galactica-1.3b
+ model_type: AutoModelForCausalLM
+ tokenizer_type: AutoTokenizer
+ load_in_8bit: false
+ datasets:
+   - path: tatsu-lab/alpaca
+     type: alpaca
+ dataset_prepared_path: last_run_prepared
+ val_set_size: 0.1
+ adapter:
+ lora_model_dir:
+ sequence_len: 1024
+ max_packed_sequence_len: 1024
+ lora_r: 8
+ lora_alpha: 16
+ lora_dropout: 0.05
+ lora_target_modules:
+   - q_proj
+   - v_proj
+ lora_fan_in_fan_out: false
+ wandb_project:
+ wandb_watch:
+ wandb_run_id:
+ wandb_log_model: checkpoint
+ output_dir: ./lora-llama-alpaca
+ batch_size: 32
+ micro_batch_size: 16
+ num_epochs: 3
+ learning_rate: 0.00003
+ train_on_inputs: false
+ group_by_length: false
+ bf16: false
+ tf32: false
+ early_stopping_patience:
+ resume_from_checkpoint:
+ local_rank:
+ special_tokens:
+   pad_token: "[PAD]"
+   bos_token: "<s>"
+   eos_token: "</s>"
+   unk_token: "<unk>"
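
As a loose illustration (not part of the commit), the YAML above can be read into an attribute-style object so that accesses in the diffed code such as cfg.sequence_len and cfg.special_tokens resolve. PyYAML and SimpleNamespace are assumptions here; axolotl's own config loader may differ.

    # Hedged sketch: load configs/galactica_1_3B.yml into an attribute-style object.
    # PyYAML + SimpleNamespace are assumptions, not axolotl's actual loader.
    from types import SimpleNamespace

    import yaml

    with open("configs/galactica_1_3B.yml", encoding="utf-8") as f:
        cfg = SimpleNamespace(**yaml.safe_load(f))

    print(cfg.base_model)               # facebook/galactica-1.3b
    print(cfg.max_packed_sequence_len)  # 1024
    print(cfg.special_tokens)           # {'pad_token': '[PAD]', 'bos_token': '<s>', ...}
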
src/axolotl/utils/data.py CHANGED
@@ -31,7 +31,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     ds_hash = str(
         md5(
             (
-                str(max_packed_sequence_len)
+                str(cfg.sequence_len)
                 + "@"
                 + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
             ).encode("utf-8")
@@ -114,21 +114,24 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
            datasets.append(ds_wrapper)
        else:
            logging.error(f"unhandled prompt tokenization strategy: {d.type}")
-    logging.info("merging and shuffling master dataset")

-    dataset = concatenate_datasets(datasets).shuffle(seed=42)
+    logging.info("tokenizing, merging, and shuffling master dataset")

+    samples = []
+    for d in datasets:
+        samples = samples + [i for i in d]
+    dataset = Dataset.from_list(samples).shuffle(seed=42)
     if cfg.local_rank == 0:
         logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
         dataset.save_to_disk(prepared_ds_path)

-    if cfg.max_packed_sequence_len is not None:
-        constant_len_dataset = ConstantLengthDataset(
-            tokenizer,
-            [dataset],
-            seq_length=max_packed_sequence_len,
-        )
-        logging.info("packing master dataset")
-        dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+    if cfg.max_packed_sequence_len is not None:
+        constant_len_dataset = ConstantLengthDataset(
+            tokenizer,
+            [dataset],
+            seq_length=max_packed_sequence_len,
+        )
+        logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
+        dataset = Dataset.from_list([_ for _ in constant_len_dataset])

     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
         logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
src/axolotl/utils/models.py CHANGED
@@ -161,6 +161,10 @@ def load_model(
         tokenizer.add_special_tokens({"pad_token": "[PAD]"})
         os.environ["TOKENIZERS_PARALLELISM"] = "false"

+    if cfg.special_tokens:
+        for k, v in cfg.special_tokens.items():
+            setattr(tokenizer, k, v)
+
     if load_in_8bit and not cfg.load_4bit:
         logging.info("converting model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
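
The models.py addition assigns each configured special token onto the tokenizer attribute of the same name. A rough standalone sketch of the same idea, with the model name and token strings taken from the galactica config above (behaviour for tokens missing from the vocabulary is tokenizer-dependent):

    # Hedged sketch of the special_tokens override: assign each configured token
    # string onto the matching tokenizer attribute.
    from transformers import AutoTokenizer

    special_tokens = {
        "pad_token": "[PAD]",
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
    }

    tokenizer = AutoTokenizer.from_pretrained("facebook/galactica-1.3b")
    for k, v in special_tokens.items():
        setattr(tokenizer, k, v)

    print(tokenizer.pad_token, tokenizer.eos_token)  # [PAD] </s>
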