winglian committed
Commit: ee26281
Parent: 9d629d8

fix evals (#447)

src/axolotl/monkeypatch/llama_attn_hijack_flash.py CHANGED
@@ -169,7 +169,7 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")

         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=is_causal
+            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
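
The hunk above pins causal=True on the packed varlen attention path. As context only (not code from this repo), here is a minimal sketch of how the cu_seqlens / max_seqlen arguments that flash_attn_varlen_qkvpacked_func expects are typically built from per-sequence lengths; it assumes flash-attn 2.x on a CUDA device, and seq_lens, nheads, headdim are made-up illustrative values.

    import torch
    from flash_attn import flash_attn_varlen_qkvpacked_func  # flash-attn 2.x

    # made-up per-sample lengths for three sequences packed into one batch
    seq_lens = torch.tensor([5, 3, 8], dtype=torch.int32, device="cuda")
    # cumulative sequence boundaries, prefixed with 0: tensor([0, 5, 8, 16])
    cu_seqlens = torch.nn.functional.pad(seq_lens.cumsum(0, dtype=torch.int32), (1, 0))
    max_seqlen = int(seq_lens.max())

    total_tokens = int(seq_lens.sum())
    nheads, headdim = 4, 32
    # qkv packed as (total_tokens, 3, nheads, headdim), i.e. the "(b s) ..." layout above
    qkv = torch.randn(total_tokens, 3, nheads, headdim, device="cuda", dtype=torch.float16)

    # causal=True keeps every packed sequence causally masked, as the hunk above enforces
    out = flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, 0.0, causal=True)
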
src/axolotl/utils/models.py CHANGED
@@ -438,7 +438,7 @@ def load_llama_adapter(model, cfg):
     )

     if cfg.lora_model_dir:
-        LOG.info("Loading pretained LORA")
+        LOG.debug("Loading pretained PEFT - llama_adapter")
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
@@ -500,6 +500,7 @@ def load_lora(model, cfg):
     )

     if cfg.lora_model_dir:
+        LOG.debug("Loading pretained PEFT - LoRA")
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
src/axolotl/utils/trainer.py CHANGED
@@ -14,12 +14,15 @@ import bitsandbytes as bnb
 import numpy as np
 import torch.cuda
 import transformers
-from datasets import set_caching_enabled
+from datasets import Dataset, set_caching_enabled
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
-from transformers.trainer_pt_utils import get_parameter_names
+from transformers.trainer_pt_utils import (
+    SequentialDistributedSampler,
+    get_parameter_names,
+)

 from axolotl.utils.callbacks import (
     GPUStatsCallback,
@@ -171,6 +174,18 @@ class AxolotlTrainer(Trainer):
         )
         return super()._get_train_sampler()

+    def _get_eval_sampler(
+        self, eval_dataset: Dataset
+    ) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size > 1 and self.args.sample_packing:
+            return SequentialDistributedSampler(
+                eval_dataset,
+                num_replicas=self.args.world_size,
+                rank=self.args.process_index,
+                batch_size=self.args.per_device_eval_batch_size,
+            )
+        return super()._get_eval_sampler(eval_dataset)
+
     def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
         if self.args.sample_packing:
             train_sampler = self._get_train_sampler()
@@ -188,27 +203,28 @@ class AxolotlTrainer(Trainer):
         )
         return super().get_train_dataloader()

-    # def get_eval_dataloader(
-    #     self, eval_dataset: Optional[Dataset] = None
-    # ) -> Union[DataLoader, MultipackDistributedDataloader]:
-    #     if self.args.sample_packing:
-    #         eval_dataset = (
-    #             eval_dataset if eval_dataset is not None else self.eval_dataset
-    #         )
-    #         eval_sampler = self._get_eval_sampler(eval_dataset)
-    #         return self.accelerator.prepare(
-    #             MultipackDistributedDataloader(
-    #                 eval_dataset,
-    #                 batch_size=self.args.eval_batch_size,
-    #                 seq_max_length=self.args.max_seq_length,
-    #                 collate_fn=self.data_collator,
-    #                 sampler=eval_sampler,
-    #                 packing_efficiency_estimate=self.args.sample_packing_efficiency,
-    #                 sample_packing_seq_len_multiplier=self.args.eval_batch_size,
-    #                 device_count=int(os.environ.get("WORLD_SIZE", 1)),
-    #             )
-    #         )
-    #     return super().get_eval_dataloader(eval_dataset)
+    def get_eval_dataloader(
+        self, eval_dataset: Optional[Dataset] = None
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        if self.args.sample_packing:
+            eval_dataset = (
+                eval_dataset if eval_dataset is not None else self.eval_dataset
+            )
+
+            eval_sampler = self._get_eval_sampler(eval_dataset)
+            return self.accelerator.prepare(
+                MultipackDistributedDataloader(
+                    eval_dataset,
+                    batch_size=self.args.eval_batch_size,
+                    seq_max_length=self.args.max_seq_length,
+                    collate_fn=self.data_collator,
+                    sampler=eval_sampler,
+                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
+                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                )
+            )
+        return super().get_eval_dataloader(eval_dataset)

     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
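
For packed evaluation across more than one process, the new _get_eval_sampler override hands each rank a contiguous, in-order shard via transformers' SequentialDistributedSampler, which is what lets the packed eval dataloader above keep predictions lined up with the dataset. A small illustrative sketch of how that sampler slices a toy dataset (not code from this repo; assumes a transformers release that still exports SequentialDistributedSampler, as it did when this change landed):

    from transformers.trainer_pt_utils import SequentialDistributedSampler

    dataset = list(range(10))  # stand-in for a tokenized eval dataset

    for rank in range(2):
        sampler = SequentialDistributedSampler(
            dataset, num_replicas=2, rank=rank, batch_size=4
        )
        print(rank, list(sampler))

    # rank 0 -> [0, 1, 2, 3, 4, 5, 6, 7]
    # rank 1 -> [8, 9, 0, 1, 2, 3, 4, 5]  (indices wrap so both shards stay equal-sized)

Because the shards are contiguous rather than interleaved, per-rank results can be concatenated back in dataset order and the wrapped padding dropped afterwards.
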