import torch
import sys
import argparse
import os
sys.path.append("..")
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
from numpy.random import default_rng
os.environ["TOKENIZERS_PARALLELISM"] = "false"

MODEL_NAME_SAVE = "Llama-3.2-3B"  # checkpoint directory name under ../train/checkpoints
FILE_SAMPLE_SIZE = 200            # number of test examples to subsample for evaluation

def get_perplexity(model, tokenizer, eval_dataset, batch_size):
    # Causal-LM collation: the collator copies input_ids into labels and
    # masks pad positions with -100 so they are ignored by the loss.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir="./tmp_trainer",
        per_device_eval_batch_size=batch_size,
        fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
        report_to="none",
    )

    trainer = Trainer(model=model, args=training_args, eval_dataset=eval_dataset, data_collator=data_collator)
    eval_results = trainer.evaluate()
    print("eval_results:", eval_results)
    # Perplexity is the exponential of the mean cross-entropy loss.
    loss = eval_results["eval_loss"]
    return torch.exp(torch.tensor(loss)).item()
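
# For cross-checking, the same quantity can be computed without the Trainer.
# A minimal sketch, not part of the original script: the helper name
# get_perplexity_manual is illustrative, and it assumes rows carrying
# "input_ids"/"attention_mask" as produced by tokenize_function below.
def get_perplexity_manual(model, eval_dataset):
    device = next(model.parameters()).device
    total_loss, n_batches = 0.0, 0
    with torch.no_grad():
        for row in eval_dataset:
            input_ids = torch.tensor([row["input_ids"]], device=device)
            attention_mask = torch.tensor([row["attention_mask"]], device=device)
            # Mask pad positions with -100, mirroring the collator's labels.
            labels = input_ids.masked_fill(attention_mask == 0, -100)
            out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            total_loss += out.loss.item()
            n_batches += 1
    # exp of the mean cross-entropy, matching the Trainer-based estimate above.
    return torch.exp(torch.tensor(total_loss / n_batches)).item()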

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Calculate perplexity on test dataset.")

    parser.add_argument('perturbation', 
                        type=str, 
                        default='reverse_full', 
                        nargs='?', 
                        help='Type of perturbation to use.')
    parser.add_argument('train_set', 
                        type=str, 
                        default='test', 
                        nargs='?', 
                        help='Dataset size for training.')
    parser.add_argument('batch_size', 
                        type=int, 
                        default=4, 
                        nargs='?', 
                        help='Batch size for evaluation.')
    parser.add_argument('seed', 
                        type=int, 
                        default=0, 
                        nargs='?', 
                        help='Random seed.')

    args = parser.parse_args()

    dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    dataset = load_dataset('../train/babylm_dataset_test.py', name=dataset_name, trust_remote_code=True)
    test_dataset = dataset['test']  # select the held-out test split
    print(test_dataset)
    
    # NOTE: the checkpoint is pinned to the 10M seed-0 training run; args.seed
    # above only controls which test examples are sampled.
    checkpoint_path = f'../train/checkpoints/{MODEL_NAME_SAVE}/babylm_{args.perturbation}_10M_seed0/runs/checkpoint-1200'
    
    # Deterministically subsample FILE_SAMPLE_SIZE examples from the test split.
    rng = default_rng(args.seed)
    indices = rng.choice(len(test_dataset), FILE_SAMPLE_SIZE, replace=False)
    sampled_test_dataset = test_dataset.select(indices)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    if tokenizer.pad_token is None:
        # Llama tokenizers ship without a pad token; fall back to EOS so the
        # fixed-length padding in tokenize_function below does not fail.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(checkpoint_path)

    model.eval()
    if torch.cuda.is_available():
        model.to('cuda')

    def tokenize_function(examples):
        # Pad/truncate every example to a fixed 1024-token window; the data
        # collator later excludes pad positions from the loss.
        return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=1024)

    tokenized_test = sampled_test_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    perplexity = get_perplexity(model, tokenizer, tokenized_test, args.batch_size)
    print(f"Perplexity on test set: {perplexity}")