fix evals (#447)
src/axolotl/monkeypatch/llama_attn_hijack_flash.py  CHANGED

@@ -169,7 +169,7 @@ def flashattn_forward(
         qkv = rearrange(qkv, "b s ... -> (b s) ...")

         output = flash_attn_varlen_qkvpacked_func(
-            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=
+            qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
        )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     elif query_states.shape == key_states.shape:
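The attention fix is a single pinned argument: with sample packing, evaluation runs through the varlen (packed) FlashAttention kernel, and `causal=True` keeps the causal mask applied inside each packed sequence. Below is a minimal sketch, not from this PR, of how `flash_attn_varlen_qkvpacked_func` consumes packed sequences; shapes follow the flash-attn documentation, every name other than the imported function is illustrative, and it requires a CUDA GPU with fp16/bf16 inputs.

import torch
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

n_heads, head_dim = 32, 128
seq_lens = [5, 3, 8]  # three sequences packed into one "batch"
total = sum(seq_lens)

# cu_seqlens marks the cumulative boundaries of the packed sequences.
cu_seqlens = torch.tensor([0, 5, 8, 16], dtype=torch.int32, device="cuda")
max_seqlen = max(seq_lens)

# Packed QKV: (total_tokens, 3, n_heads, head_dim).
qkv = torch.randn(total, 3, n_heads, head_dim,
                  dtype=torch.float16, device="cuda")

# causal=True masks within each packed sequence, never across boundaries;
# pinning it keeps eval batches causal just like training batches.
out = flash_attn_varlen_qkvpacked_func(
    qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
)
print(out.shape)  # torch.Size([16, 32, 128])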
src/axolotl/utils/models.py  CHANGED

@@ -438,7 +438,7 @@ def load_llama_adapter(model, cfg):
     )

     if cfg.lora_model_dir:
-        LOG.
+        LOG.debug("Loading pretained PEFT - llama_adapter")
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,

@@ -500,6 +500,7 @@ def load_lora(model, cfg):
     )

     if cfg.lora_model_dir:
+        LOG.debug("Loading pretained PEFT - LoRA")
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
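The models.py change only adds visibility when a pretrained PEFT adapter is loaded from `lora_model_dir`. Both messages are emitted at DEBUG level, so they surface only with debug logging enabled; a minimal sketch, assuming the module-level `LOG` is a stdlib logger named "axolotl" (the diff shows only `LOG`):

import logging

# Surface DEBUG messages such as "Loading pretained PEFT - LoRA".
# The logger name "axolotl" is an assumption, not shown in this diff.
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("axolotl").setLevel(logging.DEBUG)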
src/axolotl/utils/trainer.py  CHANGED

@@ -14,12 +14,15 @@ import bitsandbytes as bnb
 import numpy as np
 import torch.cuda
 import transformers
-from datasets import set_caching_enabled
+from datasets import Dataset, set_caching_enabled
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
-from transformers.trainer_pt_utils import get_parameter_names
+from transformers.trainer_pt_utils import (
+    SequentialDistributedSampler,
+    get_parameter_names,
+)

 from axolotl.utils.callbacks import (
     GPUStatsCallback,

@@ -171,6 +174,18 @@ class AxolotlTrainer(Trainer):
             )
         return super()._get_train_sampler()

+    def _get_eval_sampler(
+        self, eval_dataset: Dataset
+    ) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size > 1 and self.args.sample_packing:
+            return SequentialDistributedSampler(
+                eval_dataset,
+                num_replicas=self.args.world_size,
+                rank=self.args.process_index,
+                batch_size=self.args.per_device_eval_batch_size,
+            )
+        return super()._get_eval_sampler(eval_dataset)
+
     def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
         if self.args.sample_packing:
             train_sampler = self._get_train_sampler()
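The new `_get_eval_sampler` override is the core of the multi-GPU eval fix: with sample packing, each rank gets one contiguous, evenly sized slice of the eval set via `SequentialDistributedSampler` instead of an interleaved shard, so packed eval batches line up deterministically across processes. A minimal sketch of how it partitions indices; the toy dataset and sizes are illustrative, not from the PR:

from torch.utils.data import Dataset
from transformers.trainer_pt_utils import SequentialDistributedSampler

class ToyDataset(Dataset):
    """12 dummy examples; only __len__/__getitem__ matter to the sampler."""
    def __len__(self) -> int:
        return 12
    def __getitem__(self, idx: int) -> int:
        return idx

ds = ToyDataset()
for rank in range(2):
    sampler = SequentialDistributedSampler(
        ds, num_replicas=2, rank=rank, batch_size=2
    )
    # Contiguous slices: rank 0 -> [0..5], rank 1 -> [6..11]. If the
    # dataset didn't divide evenly, indices would wrap around to pad.
    print(rank, list(sampler))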
@@ -188,27 +203,28 @@ class AxolotlTrainer(Trainer):
             )
         return super().get_train_dataloader()

-    … (21 removed lines; their content was not rendered in this capture)
+    def get_eval_dataloader(
+        self, eval_dataset: Optional[Dataset] = None
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        if self.args.sample_packing:
+            eval_dataset = (
+                eval_dataset if eval_dataset is not None else self.eval_dataset
+            )
+
+            eval_sampler = self._get_eval_sampler(eval_dataset)
+            return self.accelerator.prepare(
+                MultipackDistributedDataloader(
+                    eval_dataset,
+                    batch_size=self.args.eval_batch_size,
+                    seq_max_length=self.args.max_seq_length,
+                    collate_fn=self.data_collator,
+                    sampler=eval_sampler,
+                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
+                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                )
+            )
+        return super().get_eval_dataloader(eval_dataset)

     def compute_loss(self, model, inputs, return_outputs=False):
         # use one's weighted cross entropy loss calc
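With packing enabled, the number of eval batches is driven by token counts rather than example counts, which is why the override routes evaluation through `MultipackDistributedDataloader` with a `packing_efficiency_estimate` rather than the stock HF dataloader. As a hedged back-of-envelope, the formula below is an assumption for intuition only, not the dataloader's actual internals:

import math

def estimated_packed_batches(total_tokens: int, max_seq_length: int,
                             batch_size: int, efficiency: float) -> int:
    """Rough batch count if each step packs ~max_seq_length * batch_size
    tokens at the estimated packing efficiency (illustrative helper)."""
    capacity = max_seq_length * batch_size * efficiency
    return math.ceil(total_tokens / capacity)

# e.g. 1.2M eval tokens, 2048 ctx, eval batch 2, ~90% packing efficiency:
print(estimated_packed_batches(1_200_000, 2048, 2, 0.9))  # -> 326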