# Copyright 2024-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Union
import pytest
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from peft import CPTConfig, TaskType, get_peft_model
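# Prompt template used to build CPT inputs; "{}" marks where the raw sentence / label text is inserted.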
TEMPLATE = {"input": "input: {}", "intra_seperator": " ", "output": "output: {}", "inter_seperator": "\n"}
MODEL_NAME = "hf-internal-testing/tiny-random-OPTForCausalLM"
MAX_INPUT_LENGTH = 1024
@pytest.fixture(scope="module")
def global_tokenizer():
"""Load the tokenizer fixture for the model."""
return AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")
@pytest.fixture(scope="module")
def config_text():
"""Load the SST2 dataset and prepare it for testing."""
config = CPTConfig(
cpt_token_ids=[0, 1, 2, 3, 4, 5, 6, 7], # Example token IDs for testing
cpt_mask=[1, 1, 1, 1, 1, 1, 1, 1],
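        # Type ids mirror the input_type_mask convention used in CPTDataset below:
        # 1 = input template, 2 = input text, 3 = label template, 4 = label.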
cpt_tokens_type_mask=[1, 2, 2, 2, 3, 3, 3, 4],
opt_weighted_loss_type="decay",
opt_loss_decay_factor=0.95,
opt_projection_epsilon=0.2,
opt_projection_format_epsilon=0.1,
tokenizer_name_or_path=MODEL_NAME,
)
return config
@pytest.fixture(scope="module")
def config_random():
"""Load the SST2 dataset and prepare it for testing."""
config = CPTConfig(
opt_weighted_loss_type="decay",
opt_loss_decay_factor=0.95,
opt_projection_epsilon=0.2,
opt_projection_format_epsilon=0.1,
tokenizer_name_or_path=MODEL_NAME,
)
return config
@pytest.fixture(scope="module")
def sst_data():
"""Load the SST2 dataset and prepare it for testing."""
data = load_dataset("glue", "sst2")
def add_string_labels(example):
if example["label"] == 0:
example["label_text"] = "negative"
elif example["label"] == 1:
example["label_text"] = "positive"
return example
train_dataset = data["train"].select(range(4)).map(add_string_labels)
test_dataset = data["validation"].select(range(10)).map(add_string_labels)
return {"train": train_dataset, "test": test_dataset}
@pytest.fixture(scope="module")
def collator(global_tokenizer):
class CPTDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
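        """Collator that pads input_ids, attention_mask and input_type_mask to the longest example in the batch."""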
def __init__(self, tokenizer, training=True, mlm=False):
super().__init__(tokenizer, mlm=mlm)
self.training = training
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})  # ensure a pad token is defined; may be redundant for tokenizers that already have one
def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
# Handle dict or lists with proper padding and conversion to tensor.
list_sample_mask = []
for i in range(len(examples)):
if "sample_mask" in examples[i].keys():
list_sample_mask.append(examples[i].pop("sample_mask"))
max_len = max(len(ex["input_ids"]) for ex in examples)
def pad_sequence(sequence, max_len, pad_value=0):
return sequence + [pad_value] * (max_len - len(sequence))
input_ids = torch.tensor([pad_sequence(ex["input_ids"], max_len) for ex in examples])
attention_mask = torch.tensor([pad_sequence(ex["attention_mask"], max_len) for ex in examples])
input_type_mask = torch.tensor([pad_sequence(ex["input_type_mask"], max_len) for ex in examples])
batch = {"input_ids": input_ids, "attention_mask": attention_mask, "input_type_mask": input_type_mask}
tensor_sample_mask = batch["input_ids"].clone().long()
tensor_sample_mask[:, :] = 0
for i in range(len(list_sample_mask)):
tensor_sample_mask[i, : len(list_sample_mask[i])] = list_sample_mask[i]
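            # Labels are a plain copy of the (padded) input ids, as in standard causal LM training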
batch["labels"] = batch["input_ids"].clone()
if not self.training:
batch["sample_mask"] = tensor_sample_mask
return batch
collator = CPTDataCollatorForLanguageModeling(global_tokenizer, training=True, mlm=False)
return collator
def dataset(data, tokenizer):
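    """Wrap raw SST2 samples in a CPTDataset built with the module-level TEMPLATE."""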
class CPTDataset(Dataset):
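        """Tokenizes sentences and string labels into CPT-formatted input ids with per-token type ids."""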
def __init__(self, samples, tokenizer, template, max_length=MAX_INPUT_LENGTH):
self.template = template
self.tokenizer = tokenizer
self.max_length = max_length
self.attention_mask = []
self.input_ids = []
self.input_type_mask = []
self.inter_seperator_ids = self._get_input_ids(template["inter_seperator"])
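            # Pre-tokenize every sample into (input_ids, attention_mask, input_type_mask) triples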
for sample_i in tqdm(samples):
input_text, label = sample_i["sentence"], sample_i["label_text"]
input_ids, attention_mask, input_type_mask = self.preprocess_sentence(input_text, label)
self.input_ids.append(input_ids)
self.attention_mask.append(attention_mask)
self.input_type_mask.append(input_type_mask)
def _get_input_ids(self, text):
return self.tokenizer(text, add_special_tokens=False)["input_ids"]
def preprocess_sentence(self, input_text, label):
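            # Tokenize the template pieces, the raw sentence and the label separately so each span can get its own type id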
input_template_part_1_text, input_template_part_2_text = self.template["input"].split("{}")
input_template_tokenized_part1 = self._get_input_ids(input_template_part_1_text)
input_tokenized = self._get_input_ids(input_text)
input_template_tokenized_part2 = self._get_input_ids(input_template_part_2_text)
sep_tokenized = self._get_input_ids(self.template["intra_seperator"])
label_template_part_1, label_template_part_2 = self.template["output"].split("{}")
label_template_part1_tokenized = self._get_input_ids(label_template_part_1)
label_tokenized = self._get_input_ids(label)
label_template_part2_tokenized = self._get_input_ids(label_template_part_2)
eos = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []
input_ids = (
input_template_tokenized_part1
+ input_tokenized
+ input_template_tokenized_part2
+ sep_tokenized
+ label_template_part1_tokenized
+ label_tokenized
+ label_template_part2_tokenized
+ eos
)
            # Mark each token span with a type id so the CPT loss can treat template, input and label tokens differently
attention_mask = [1] * len(input_ids)
input_type_mask = (
[1] * len(input_template_tokenized_part1)
+ [2] * len(input_tokenized)
+ [1] * len(input_template_tokenized_part2)
+ [0] * len(sep_tokenized)
+ [3] * len(label_template_part1_tokenized)
+ [4] * len(label_tokenized)
+ [3] * len(label_template_part2_tokenized)
+ [0] * len(eos)
)
assert len(input_type_mask) == len(input_ids) == len(attention_mask)
return input_ids, attention_mask, input_type_mask
def __len__(self):
return len(self.input_ids)
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
"input_type_mask": self.input_type_mask[idx],
}
    return CPTDataset(data, tokenizer, TEMPLATE)
def test_model_initialization_text(global_tokenizer, config_text):
"""Test model loading and PEFT model initialization."""
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = get_peft_model(base_model, config_text)
assert model is not None, "PEFT model initialization failed"
def test_model_initialization_random(global_tokenizer, config_random):
"""Test model loading and PEFT model initialization."""
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = get_peft_model(base_model, config_random)
assert model is not None, "PEFT model initialization failed"
def test_model_initialization_wrong_task_type_warns():
# TODO: adjust this test to check for an error with PEFT v0.18.0
msg = "CPTConfig only supports task_type = CAUSAL_LM, setting it automatically"
with pytest.warns(FutureWarning, match=msg):
config = CPTConfig(task_type=TaskType.SEQ_CLS)
assert config.task_type == TaskType.CAUSAL_LM
def test_model_training_random(sst_data, global_tokenizer, collator, config_random):
"""Perform a short training run to verify the model and data integration."""
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = get_peft_model(base_model, config_random)
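    # Snapshot the CPT embedding weights so we can verify below that they stay frozen during training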
emb = model.prompt_encoder.default.embedding.weight.data.clone().detach()
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=1,
num_train_epochs=2,
remove_unused_columns=False,
save_strategy="no",
logging_steps=1,
)
train_dataset = dataset(sst_data["train"], global_tokenizer)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collator)
trainer.train()
# Verify that the embedding tensor remains unchanged (frozen)
assert torch.all(model.prompt_encoder.default.embedding.weight.data.clone().detach().cpu() == emb.cpu())
delta_emb = model.prompt_encoder.default.get_projection().clone().detach()
norm_delta = delta_emb.norm(dim=1).cpu()
epsilon = model.prompt_encoder.default.get_epsilon().cpu()
# Verify that the change in tokens is constrained to epsilon
assert torch.all(norm_delta <= epsilon)
def test_model_batch_training_text(sst_data, global_tokenizer, collator, config_text):
"""Perform a short training run to verify the model and data integration."""
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model = get_peft_model(base_model, config_text)
emb = model.prompt_encoder.default.embedding.weight.data.clone().detach()
training_args = TrainingArguments(
output_dir="./results",
per_device_train_batch_size=2,
num_train_epochs=2,
remove_unused_columns=False,
save_strategy="no",
logging_steps=1,
)
train_dataset = dataset(sst_data["train"], global_tokenizer)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collator)
trainer.train()
# Verify that the embedding tensor remains unchanged (frozen)
assert torch.all(model.prompt_encoder.default.embedding.weight.data.clone().detach().cpu() == emb.cpu())
cpt_tokens_type_mask = torch.Tensor(config_text.cpt_tokens_type_mask).long()
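    # Types 1-3 mark template/input tokens; type 4 marks label tokens, whose projection delta should stay zero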
non_label_idx = (cpt_tokens_type_mask == 1) | (cpt_tokens_type_mask == 2) | (cpt_tokens_type_mask == 3)
delta_emb = model.prompt_encoder.default.get_projection().clone().detach()
norm_delta = delta_emb.norm(dim=1).cpu()
epsilon = model.prompt_encoder.default.get_epsilon().cpu()
# Verify that the change in tokens is constrained to epsilon
assert torch.all(norm_delta <= epsilon)
# Ensure that label tokens remain unchanged
assert torch.all((norm_delta == 0) == (~non_label_idx))