from typing import Any, Union

import pytest
import torch
from datasets import load_dataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

from peft import CPTConfig, TaskType, get_peft_model


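# Prompt template consumed by CPTDataset below: each SST2 example is rendered as
# "input: <sentence> output: <label>".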
TEMPLATE = {"input": "input: {}", "intra_seperator": " ", "output": "output: {}", "inter_seperator": "\n"}

MODEL_NAME = "hf-internal-testing/tiny-random-OPTForCausalLM"
MAX_INPUT_LENGTH = 1024


@pytest.fixture(scope="module")
def global_tokenizer():
    """Load the tokenizer for the tiny test model."""
    return AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")


@pytest.fixture(scope="module")
def config_text():
    """Create a CPT config initialized from explicit context token ids and type masks."""
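    # cpt_token_ids supplies the ids of the CPT context tokens; cpt_tokens_type_mask is assumed to
    # follow the same type-code convention as input_type_mask in CPTDataset below
    # (1/3 = template, 2 = input, 4 = label).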
    config = CPTConfig(
        cpt_token_ids=[0, 1, 2, 3, 4, 5, 6, 7],
        cpt_mask=[1, 1, 1, 1, 1, 1, 1, 1],
        cpt_tokens_type_mask=[1, 2, 2, 2, 3, 3, 3, 4],
        opt_weighted_loss_type="decay",
        opt_loss_decay_factor=0.95,
        opt_projection_epsilon=0.2,
        opt_projection_format_epsilon=0.1,
        tokenizer_name_or_path=MODEL_NAME,
    )
    return config


@pytest.fixture(scope="module")
def config_random():
    """Create a CPT config without explicit context token ids (default context initialization)."""
    config = CPTConfig(
        opt_weighted_loss_type="decay",
        opt_loss_decay_factor=0.95,
        opt_projection_epsilon=0.2,
        opt_projection_format_epsilon=0.1,
        tokenizer_name_or_path=MODEL_NAME,
    )
    return config


@pytest.fixture(scope="module")
def sst_data():
    """Load the SST2 dataset and prepare it for testing."""
    data = load_dataset("glue", "sst2")

    def add_string_labels(example):
        if example["label"] == 0:
            example["label_text"] = "negative"
        elif example["label"] == 1:
            example["label_text"] = "positive"
        return example

    train_dataset = data["train"].select(range(4)).map(add_string_labels)
    test_dataset = data["validation"].select(range(10)).map(add_string_labels)

    return {"train": train_dataset, "test": test_dataset}


@pytest.fixture(scope="module")
def collator(global_tokenizer):
    class CPTDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
        def __init__(self, tokenizer, training=True, mlm=False):
            super().__init__(tokenizer, mlm=mlm)
            self.training = training
            self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})

        def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
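            # Pop any per-example sample_mask before the manual right-padding of the remaining
            # fields; it is re-padded and re-attached further below.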
            list_sample_mask = []
            for i in range(len(examples)):
                if "sample_mask" in examples[i].keys():
                    list_sample_mask.append(examples[i].pop("sample_mask"))

            max_len = max(len(ex["input_ids"]) for ex in examples)

            def pad_sequence(sequence, max_len, pad_value=0):
                return sequence + [pad_value] * (max_len - len(sequence))

            input_ids = torch.tensor([pad_sequence(ex["input_ids"], max_len) for ex in examples])
            attention_mask = torch.tensor([pad_sequence(ex["attention_mask"], max_len) for ex in examples])
            input_type_mask = torch.tensor([pad_sequence(ex["input_type_mask"], max_len) for ex in examples])

            batch = {"input_ids": input_ids, "attention_mask": attention_mask, "input_type_mask": input_type_mask}

            tensor_sample_mask = batch["input_ids"].clone().long()
            tensor_sample_mask[:, :] = 0
            for i in range(len(list_sample_mask)):
                tensor_sample_mask[i, : len(list_sample_mask[i])] = list_sample_mask[i]

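            # Labels are a plain copy of input_ids; the CPT model is assumed to apply its own
            # token-type loss weighting based on input_type_mask. The padded sample_mask is only
            # attached at evaluation time (training=False).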
            batch["labels"] = batch["input_ids"].clone()
            if not self.training:
                batch["sample_mask"] = tensor_sample_mask

            return batch

    collator = CPTDataCollatorForLanguageModeling(global_tokenizer, training=True, mlm=False)
    return collator


def dataset(data, tokenizer):
    class CPTDataset(Dataset):
        def __init__(self, samples, tokenizer, template, max_length=MAX_INPUT_LENGTH):
            self.template = template
            self.tokenizer = tokenizer
            self.max_length = max_length

            self.attention_mask = []
            self.input_ids = []
            self.input_type_mask = []
            self.inter_seperator_ids = self._get_input_ids(template["inter_seperator"])

            for sample_i in tqdm(samples):
                input_text, label = sample_i["sentence"], sample_i["label_text"]
                input_ids, attention_mask, input_type_mask = self.preprocess_sentence(input_text, label)

                self.input_ids.append(input_ids)
                self.attention_mask.append(attention_mask)
                self.input_type_mask.append(input_type_mask)

        def _get_input_ids(self, text):
            return self.tokenizer(text, add_special_tokens=False)["input_ids"]

        def preprocess_sentence(self, input_text, label):
            input_template_part_1_text, input_template_part_2_text = self.template["input"].split("{}")
            input_template_tokenized_part1 = self._get_input_ids(input_template_part_1_text)
            input_tokenized = self._get_input_ids(input_text)
            input_template_tokenized_part2 = self._get_input_ids(input_template_part_2_text)

            sep_tokenized = self._get_input_ids(self.template["intra_seperator"])

            label_template_part_1, label_template_part_2 = self.template["output"].split("{}")
            label_template_part1_tokenized = self._get_input_ids(label_template_part_1)
            label_tokenized = self._get_input_ids(label)
            label_template_part2_tokenized = self._get_input_ids(label_template_part_2)

            eos = [self.tokenizer.eos_token_id] if self.tokenizer.eos_token_id is not None else []
            input_ids = (
                input_template_tokenized_part1
                + input_tokenized
                + input_template_tokenized_part2
                + sep_tokenized
                + label_template_part1_tokenized
                + label_tokenized
                + label_template_part2_tokenized
                + eos
            )

            attention_mask = [1] * len(input_ids)
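            # Token type codes: 1 = input template, 2 = input sentence, 3 = output template,
            # 4 = label, 0 = separator / EOS.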
            input_type_mask = (
                [1] * len(input_template_tokenized_part1)
                + [2] * len(input_tokenized)
                + [1] * len(input_template_tokenized_part2)
                + [0] * len(sep_tokenized)
                + [3] * len(label_template_part1_tokenized)
                + [4] * len(label_tokenized)
                + [3] * len(label_template_part2_tokenized)
                + [0] * len(eos)
            )

            assert len(input_type_mask) == len(input_ids) == len(attention_mask)

            return input_ids, attention_mask, input_type_mask

        def __len__(self):
            return len(self.input_ids)

        def __getitem__(self, idx):
            return {
                "input_ids": self.input_ids[idx],
                "attention_mask": self.attention_mask[idx],
                "input_type_mask": self.input_type_mask[idx],
            }

    dataset = CPTDataset(data, tokenizer, TEMPLATE)

    return dataset


def test_model_initialization_text(global_tokenizer, config_text):
    """Test model loading and PEFT model initialization with a text-initialized CPT config."""
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

    model = get_peft_model(base_model, config_text)
    assert model is not None, "PEFT model initialization failed"


def test_model_initialization_random(global_tokenizer, config_random):
    """Test model loading and PEFT model initialization with a default-initialized CPT config."""
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

    model = get_peft_model(base_model, config_random)
    assert model is not None, "PEFT model initialization failed"


def test_model_initialization_wrong_task_type_warns():
    msg = "CPTConfig only supports task_type = CAUSAL_LM, setting it automatically"
    with pytest.warns(FutureWarning, match=msg):
        config = CPTConfig(task_type=TaskType.SEQ_CLS)
    assert config.task_type == TaskType.CAUSAL_LM


def test_model_training_random(sst_data, global_tokenizer, collator, config_random):
    """Perform a short training run to verify the model and data integration."""
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    model = get_peft_model(base_model, config_random)
    emb = model.prompt_encoder.default.embedding.weight.data.clone().detach()
    training_args = TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=1,
        num_train_epochs=2,
        remove_unused_columns=False,
        save_strategy="no",
        logging_steps=1,
    )

    train_dataset = dataset(sst_data["train"], global_tokenizer)

    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collator)

    trainer.train()

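    # The base prompt embedding table should remain unchanged by training; the learned update is
    # expected to live in get_projection(), with per-token norms bounded by get_epsilon().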
    assert torch.all(model.prompt_encoder.default.embedding.weight.data.clone().detach().cpu() == emb.cpu())

    delta_emb = model.prompt_encoder.default.get_projection().clone().detach()
    norm_delta = delta_emb.norm(dim=1).cpu()
    epsilon = model.prompt_encoder.default.get_epsilon().cpu()

    assert torch.all(norm_delta <= epsilon)


def test_model_batch_training_text(sst_data, global_tokenizer, collator, config_text):
    """Perform a short batched training run to verify the model and data integration."""
    base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
    model = get_peft_model(base_model, config_text)
    emb = model.prompt_encoder.default.embedding.weight.data.clone().detach()

    training_args = TrainingArguments(
        output_dir="./results",
        per_device_train_batch_size=2,
        num_train_epochs=2,
        remove_unused_columns=False,
        save_strategy="no",
        logging_steps=1,
    )

    train_dataset = dataset(sst_data["train"], global_tokenizer)

    trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=collator)

    trainer.train()

    assert torch.all(model.prompt_encoder.default.embedding.weight.data.clone().detach().cpu() == emb.cpu())

    cpt_tokens_type_mask = torch.tensor(config_text.cpt_tokens_type_mask).long()
    non_label_idx = (cpt_tokens_type_mask == 1) | (cpt_tokens_type_mask == 2) | (cpt_tokens_type_mask == 3)

    delta_emb = model.prompt_encoder.default.get_projection().clone().detach()
    norm_delta = delta_emb.norm(dim=1).cpu()
    epsilon = model.prompt_encoder.default.get_epsilon().cpu()

    assert torch.all(norm_delta <= epsilon)

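    # Per the assertion below: with the text-initialized config, only the label tokens (type 4) of
    # the CPT context keep a zero update, while template and input tokens (types 1-3) are modified.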
    assert torch.all((norm_delta == 0) == (~non_label_idx))