|
import torch |
|
import sys |
|
import argparse |
|
import os |
|
sys.path.append("..") |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling |
|
from datasets import load_dataset |
|
from numpy.random import default_rng |
|
# Disable HF tokenizers' internal thread parallelism to avoid fork-related
# warnings/deadlocks when the Trainer spawns dataloader workers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Model directory name used to build the checkpoint path in __main__.
MODEL_NAME_SAVE = "Llama-3.2-3B"

# Number of test examples sampled (without replacement) for evaluation.
FILE_SAMPLE_SIZE = 200
|
|
|
def get_perplexities(model, eval_dataset, batch_size, eval_tokenizer=None):
    """Compute the perplexity of `model` on `eval_dataset` using the HF Trainer.

    Args:
        model: A causal LM (AutoModelForCausalLM) already placed on its device.
        eval_dataset: Tokenized dataset with `input_ids`/`attention_mask` columns.
        batch_size: Per-device evaluation batch size.
        eval_tokenizer: Tokenizer for the data collator. Defaults to the
            module-level global `tokenizer` for backward compatibility with
            existing callers that relied on that implicit global.

    Returns:
        float: exp(mean eval cross-entropy loss), i.e. the perplexity.
    """
    # Previously this function silently read the global `tokenizer`, which is
    # only defined inside __main__; keep that as a fallback but allow explicit
    # injection so the function is usable on its own.
    if eval_tokenizer is None:
        eval_tokenizer = tokenizer

    data_collator = DataCollatorForLanguageModeling(tokenizer=eval_tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="./tmp_trainer",
        per_device_eval_batch_size=batch_size,
        # fp16 evaluation is only supported with CUDA; TrainingArguments
        # raises on CPU-only machines if fp16=True, so gate it here.
        fp16=torch.cuda.is_available(),
        report_to="none",
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )
    eval_results = trainer.evaluate()
    print("eval_results:", eval_results)

    # Perplexity is the exponential of the mean cross-entropy loss.
    loss = eval_results['eval_loss']
    perplexity = torch.exp(torch.tensor(loss)).item()
    return perplexity
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Calculate perplexity on test dataset.")

    parser.add_argument('perturbation',
                        type=str,
                        default='reverse_full',
                        nargs='?',
                        help='Type of perturbation to use.')
    parser.add_argument('train_set',
                        type=str,
                        default='test',
                        nargs='?',
                        help='Dataset size for training.')
    parser.add_argument('batch_size',
                        type=int,
                        default=4,
                        nargs='?',
                        help='Batch size for evaluation.')
    parser.add_argument('seed',
                        type=int,
                        default=0,
                        nargs='?',
                        help='Random seed.')

    args = parser.parse_args()

    # Load the test split of the requested perturbed BabyLM dataset.
    dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    dataset = load_dataset('../train/babylm_dataset_test.py', name=dataset_name, trust_remote_code=True)
    test_dataset = dataset['test']
    print(test_dataset)

    # NOTE(review): the checkpoint path hard-codes `10M_seed0` and
    # `checkpoint-1200` regardless of args.seed / args.train_set — confirm
    # this is intentional (e.g. only one training run exists per perturbation).
    checkpoint_path = f'../train/checkpoints/{MODEL_NAME_SAVE}/babylm_{args.perturbation}_10M_seed0/runs/checkpoint-1200'

    # Sample a fixed number of test examples. Clamp the sample size so
    # rng.choice(..., replace=False) cannot raise on a split smaller than
    # FILE_SAMPLE_SIZE; the draw is unchanged when the split is large enough.
    rng = default_rng(args.seed)
    sample_size = min(FILE_SAMPLE_SIZE, len(test_dataset))
    indices = rng.choice(len(test_dataset), sample_size, replace=False)
    sampled_test_dataset = test_dataset.select(indices)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    model = AutoModelForCausalLM.from_pretrained(checkpoint_path)

    model.eval()
    if torch.cuda.is_available():
        model.to('cuda')

    def tokenize_function(examples):
        # Pad/truncate every example to a fixed length so batches stack cleanly.
        return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=1024)

    tokenized_test = sampled_test_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Bug fix: the evaluation batch size was previously hard-coded to 1,
    # silently ignoring the parsed args.batch_size.
    perplexity = get_perplexities(model, tokenized_test, args.batch_size)
    print(f"Perplexity on test set: {perplexity}")