File size: 20,834 Bytes
302920f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 |
# Note: These tests were copied from test_common_gpu.py and test_gpu_examples.py as they can run on CPU too.
#
# Copyright 2025-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import os
import tempfile
import unittest
import pytest
import torch
from accelerate.utils.memory import clear_device_cache
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from peft import (
AdaLoraConfig,
LoraConfig,
OFTConfig,
PeftModel,
get_peft_model,
prepare_model_for_kbit_training,
)
from peft.tuners.lora import GPTQLoraLinear
from peft.utils import SAFETENSORS_WEIGHTS_NAME, infer_device
from .testing_utils import (
device_count,
load_dataset_english_quotes,
require_gptqmodel,
require_optimum,
require_torch_multi_accelerator,
)
@require_gptqmodel
class PeftGPTQModelCommonTests(unittest.TestCase):
    r"""
    A common tester to run common operations that are performed on GPU/CPU such as generation, loading in 8bit, etc.
    """

    def setUp(self):
        self.causal_lm_model_id = "facebook/opt-350m"
        self.device = infer_device()

    def tearDown(self):
        r"""
        Efficient mechanism to free GPU memory after each test. Based on
        https://github.com/huggingface/transformers/issues/21094
        """
        clear_device_cache(garbage_collection=True)
        gc.collect()

    def _check_gptq_quantization_from_pretrained_safetensors(self, config, adapter_attr):
        r"""
        Shared body of the save/reload tests for GPTQ-quantized models.

        Creates an adapter from ``config`` on top of a 4-bit GPTQ OPT model, runs generation, saves it with
        ``save_pretrained``, reloads it via ``PeftModel.from_pretrained``, loads the same checkpoint a second time
        under the name ``"adapter2"`` (regression test for #1239) and checks that both adapters live in the same
        layer.

        Args:
            config: PEFT config used to create the initial (``"default"``) adapter.
            adapter_attr: Name of the per-adapter container on the tuned ``q_proj`` layer (e.g. ``"lora_A"`` or
                ``"oft_R"``) that must contain both adapter names after reloading.
        """
        from transformers import GPTQConfig

        model_id = "marcsun13/opt-350m-gptq-4bit"
        quantization_config = GPTQConfig(bits=4, use_exllama=False)
        kwargs = {
            "pretrained_model_name_or_path": model_id,
            "torch_dtype": torch.float16,
            "device_map": "auto",
            "quantization_config": quantization_config,
        }
        input_ids = torch.LongTensor([[0, 2, 3, 1]])

        model = AutoModelForCausalLM.from_pretrained(**kwargs)
        model = prepare_model_for_kbit_training(model)
        peft_model = get_peft_model(model, config)
        peft_model.generate(input_ids=input_ids.to(peft_model.device))

        with tempfile.TemporaryDirectory() as tmp_dir:
            peft_model.save_pretrained(tmp_dir)
            model = AutoModelForCausalLM.from_pretrained(**kwargs)
            model = PeftModel.from_pretrained(model, tmp_dir)
            model = prepare_model_for_kbit_training(model)
            # Send inputs to the reloaded model's own device, not to the earlier ``peft_model``'s device.
            model.generate(input_ids=input_ids.to(model.device))

            # loading a 2nd adapter works, #1239
            model.load_adapter(tmp_dir, "adapter2")
            model.set_adapter("adapter2")
            model.generate(input_ids=input_ids.to(model.device))

            # check that both adapters are in the same layer
            q_proj = model.base_model.model.model.decoder.layers[0].self_attn.q_proj
            adapters = getattr(q_proj, adapter_attr)
            assert "default" in adapters
            assert "adapter2" in adapters

    def test_lora_gptq_quantization_from_pretrained_safetensors(self):
        r"""
        Tests that the gptqmodel quantization using LoRA works as expected with safetensors weights.
        """
        self._check_gptq_quantization_from_pretrained_safetensors(LoraConfig(task_type="CAUSAL_LM"), "lora_A")

    def test_oft_gptq_quantization_from_pretrained_safetensors(self):
        r"""
        Tests that the gptqmodel quantization using OFT works as expected with safetensors weights.
        """
        self._check_gptq_quantization_from_pretrained_safetensors(OFTConfig(task_type="CAUSAL_LM"), "oft_R")
@require_gptqmodel
@require_optimum
class PeftGPTQModelTests(unittest.TestCase):
    r"""
    GPTQ + peft tests
    """

    def setUp(self):
        from transformers import GPTQConfig

        self.causal_lm_model_id = "marcsun13/opt-350m-gptq-4bit"
        self.quantization_config = GPTQConfig(bits=4, backend="auto_trainable")
        self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)

    def tearDown(self):
        r"""
        Efficient mechanism to free GPU memory after each test. Based on
        https://github.com/huggingface/transformers/issues/21094
        """
        clear_device_cache(garbage_collection=True)

    def _check_inference_finite(self, model, batch):
        r"""Run one eval-mode forward pass on ``batch`` and assert all logits are finite; restores train/eval mode."""
        # try inference without Trainer class
        training = model.training
        model.eval()
        output = model(**batch.to(model.device))
        assert torch.isfinite(output.logits).all()
        model.train(training)

    def _load_quantized_model(self):
        r"""Load ``self.causal_lm_model_id`` in fp16 with ``device_map="auto"`` and the test's GPTQ config."""
        return AutoModelForCausalLM.from_pretrained(
            self.causal_lm_model_id,
            torch_dtype=torch.float16,
            device_map="auto",
            quantization_config=self.quantization_config,
        )

    def _load_tokenized_quotes(self):
        r"""Return the english-quotes dataset with its ``"quote"`` column tokenized by ``self.tokenizer``."""
        data = load_dataset_english_quotes()
        return data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)

    def _train_save_and_check(self, model, data, tmp_dir):
        r"""
        Train ``model`` for a few steps on ``data["train"]``, save the adapter to ``tmp_dir`` and check the result.

        Asserts that the adapter config and safetensors weights were written and that a final ``train_loss`` was
        logged.
        """
        trainer = Trainer(
            model=model,
            train_dataset=data["train"],
            args=TrainingArguments(
                per_device_train_batch_size=4,
                gradient_accumulation_steps=4,
                warmup_steps=2,
                max_steps=3,
                learning_rate=2e-4,
                fp16=True,
                logging_steps=1,
                output_dir=tmp_dir,
            ),
            data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False),
        )
        model.config.use_cache = False
        trainer.train()

        model.cpu().save_pretrained(tmp_dir)

        saved_files = os.listdir(tmp_dir)
        assert "adapter_config.json" in saved_files
        assert SAFETENSORS_WEIGHTS_NAME in saved_files

        # assert loss is not None
        assert trainer.state.log_history[-1]["train_loss"] is not None

    def _check_causal_lm_training(self, config, multi_accelerator=False):
        r"""
        Load the quantized model, wrap it with ``config`` and run ``_train_save_and_check``.

        When ``multi_accelerator`` is True, additionally assert the model is spread over all ``device_count``
        accelerators and flag it as model-parallel before wrapping.
        """
        with tempfile.TemporaryDirectory() as tmp_dir:
            model = self._load_quantized_model()
            if multi_accelerator:
                assert set(model.hf_device_map.values()) == set(range(device_count))

            model = prepare_model_for_kbit_training(model)
            if multi_accelerator:
                model.model_parallel = True
                model.is_parallelizable = True

            model = get_peft_model(model, config)
            data = self._load_tokenized_quotes()
            self._train_save_and_check(model, data, tmp_dir)

    def _check_non_default_adapter_name(self, config):
        r"""
        Check that wrapping with ``adapter_name="other"`` yields the same parameter counts as the default name.
        """
        # default adapter name
        model = self._load_quantized_model()
        model = prepare_model_for_kbit_training(model)
        model = get_peft_model(model, config)
        n_trainable_default, n_total_default = model.get_nb_trainable_parameters()

        # other adapter name
        model = self._load_quantized_model()
        model = prepare_model_for_kbit_training(model)
        model = get_peft_model(model, config, adapter_name="other")
        n_trainable_other, n_total_other = model.get_nb_trainable_parameters()

        assert n_trainable_other > 0
        # sanity check
        assert n_trainable_default == n_trainable_other
        assert n_total_default == n_total_other

    def test_causal_lm_training(self):
        r"""
        Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
        correctly.
        """
        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self._check_causal_lm_training(config)

    def test_oft_causal_lm_training(self):
        r"""
        Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set
        correctly.
        """
        config = OFTConfig(
            r=0,
            oft_block_size=8,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        self._check_causal_lm_training(config)

    @pytest.mark.single_gpu_tests
    def test_adalora_causalLM(self):
        r"""
        Tests the gptq training with adalora
        """
        model = self._load_quantized_model()
        model = prepare_model_for_kbit_training(model)

        peft_config = AdaLoraConfig(
            total_step=40,
            init_r=6,
            target_r=4,
            tinit=10,
            tfinal=20,
            deltaT=5,
            beta1=0.3,
            beta2=0.3,
            orth_reg_weight=0.2,
            lora_alpha=32,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, peft_config)

        data = self._load_tokenized_quotes()
        # Forward pass outside Trainer before training; the same tokenizer instance is reused.
        batch = self.tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True)
        self._check_inference_finite(model, batch)

        with tempfile.TemporaryDirectory() as tmp_dir:
            self._train_save_and_check(model, data, tmp_dir)

    @pytest.mark.multi_gpu_tests
    @require_torch_multi_accelerator
    def test_causal_lm_training_multi_accelerator(self):
        r"""
        Test the CausalLM training on a multi-accelerator device. The test would simply fail if the adapters are not
        set correctly.
        """
        config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules=["q_proj", "v_proj"],
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
        )
        self._check_causal_lm_training(config, multi_accelerator=True)

    @pytest.mark.multi_gpu_tests
    @require_torch_multi_accelerator
    def test_oft_causal_lm_training_multi_accelerator(self):
        r"""
        Test the CausalLM training on a multi-accelerator device. The test would simply fail if the adapters are not
        set correctly.
        """
        config = OFTConfig(
            r=0,
            oft_block_size=8,
            target_modules=["q_proj", "v_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )
        self._check_causal_lm_training(config, multi_accelerator=True)

    def test_non_default_adapter_name(self):
        # See issue 1346
        config = LoraConfig(
            r=16,
            target_modules=["q_proj", "v_proj"],
            task_type="CAUSAL_LM",
        )
        self._check_non_default_adapter_name(config)

    def test_oft_non_default_adapter_name(self):
        # See issue 1346
        config = OFTConfig(
            r=0,
            oft_block_size=8,
            target_modules=["q_proj", "v_proj"],
            task_type="CAUSAL_LM",
        )
        self._check_non_default_adapter_name(config)

    def test_load_lora(self):
        r"""
        Load a GPTQ model plus a LoRA adapter from the Hub and check ranks and generation.

        Asserts that ``v_proj`` and ``gate_proj`` in layer 5 are ``GPTQLoraLinear`` with ranks 128 and 256
        respectively, and that generation for "Capital of France is" contains "paris".
        """
        model_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit"
        adapter_id = "ModelCloud/Llama-3.2-1B-gptqmodel-ci-4bit-lora"
        model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
        model.load_adapter(adapter_id)

        # assert dynamic rank
        v_proj_module = model.model.layers[5].self_attn.v_proj
        assert isinstance(v_proj_module, GPTQLoraLinear)
        assert v_proj_module.lora_A["default"].weight.data.shape[0] == 128
        assert v_proj_module.lora_B["default"].weight.data.shape[1] == 128

        gate_proj_module = model.model.layers[5].mlp.gate_proj
        assert isinstance(gate_proj_module, GPTQLoraLinear)
        assert gate_proj_module.lora_A["default"].weight.data.shape[0] == 256
        assert gate_proj_module.lora_B["default"].weight.data.shape[1] == 256

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        inp = tokenizer("Capital of France is", return_tensors="pt").to(model.device)
        tokens = model.generate(**inp)[0]
        result = tokenizer.decode(tokens)
        assert "paris" in result.lower()
|