| | |
| | |
| |
|
| | |
| | import sys |
| | sys.path.append("..") |
| |
|
| | from transformers import GPT2LMHeadModel |
| | from gpt2_no_positional_encoding_model import GPT2NoPositionalEncodingLMHeadModel |
| | from utils import CHECKPOINT_READ_PATH, PERTURBATIONS, BABYLM_DATA_PATH, \ |
| | PAREN_MODELS, gpt2_original_tokenizer |
| | from tqdm import tqdm |
| | from glob import glob |
| | from numpy.random import default_rng |
| | import pandas as pd |
| | import torch |
| | import itertools |
| | import argparse |
| | import os |
| |
|
| |
|
# Total number of optimizer steps the evaluated models were trained for.
MAX_TRAINING_STEPS = 3000
# Checkpoints were saved every 100 steps: 100, 200, ..., 3000 (inclusive).
CHECKPOINTS = list(range(100, MAX_TRAINING_STEPS+1, 100))
| |
|
| |
|
def create_attention_mask(token_lists):
    """Build a binary (1 = real token, 0 = padding) attention mask.

    Args:
        token_lists: list of variable-length lists of token ids.

    Returns:
        LongTensor of shape (batch, max_len) with ones over each sequence's
        real positions and zeros over right-padding.
    """
    # Width of the batch is the longest sequence.
    seq_length = max(len(tokens) for tokens in token_lists)
    batch_size = len(token_lists)
    # torch.zeros with an explicit integer dtype is the idiomatic form of
    # the original torch.full((...), 0) (which also yielded int64).
    mask = torch.zeros(batch_size, seq_length, dtype=torch.long)

    for i, tokens in enumerate(token_lists):
        mask[i, :len(tokens)] = 1

    return mask
| |
|
def create_input_ids(token_lists, pad_token_id):
    """Right-pad every sequence to the batch's max length and stack them
    into a single (batch, max_len) integer tensor."""
    width = max(len(tokens) for tokens in token_lists)
    rows = [list(tokens) + [pad_token_id] * (width - len(tokens))
            for tokens in token_lists]
    return torch.tensor(rows)
| |
|
| |
|
def get_perplexities(model, token_lists, pad_token_id, device="cuda"):
    """Compute a perplexity for each token sequence in the batch.

    Args:
        model: causal LM called as model(input_ids=..., labels=...,
            attention_mask=...) returning an object with a ``.logits``
            attribute of shape (batch, seq_len, vocab).
        token_lists: list of variable-length lists of int token ids.
        pad_token_id: id used to right-pad shorter sequences.
        device: torch device for the forward pass.

    Returns:
        List of float perplexities, one per input sequence; loss at padded
        positions is masked out of each sequence's average.
    """
    input_ids = create_input_ids(token_lists, pad_token_id).to(device)
    labels = input_ids.clone()
    attention_mask = create_attention_mask(token_lists).to(device)

    # Inference only: disable autograd so no graph/activations are retained
    # (the original forward pass accumulated gradient state needlessly).
    with torch.no_grad():
        outputs = model(input_ids=input_ids, labels=labels,
                        attention_mask=attention_mask)

    # Shift so that the logits at position t are scored against token t+1.
    shift_logits = outputs.logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_attention_mask = attention_mask[..., 1:].contiguous()

    # Per-token cross entropy (reduction='none') so padding can be masked.
    loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1))
    loss = loss.view(shift_labels.size())

    # Zero the loss at padded positions, then average over real tokens only.
    # NOTE(review): a length-1 sequence would make the mask sum 0 and yield
    # NaN — assumed not to occur in this data; confirm upstream.
    loss = loss * shift_attention_mask
    per_example_loss = loss.sum(dim=1) / shift_attention_mask.sum(dim=1)
    return torch.exp(per_example_loss).tolist()
| |
|
| |
|
if __name__ == "__main__":

    # ---- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser(
        prog='Edge probing',
        description='Edge probing experiments')
    parser.add_argument('perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform BabyLM dataset')
    parser.add_argument('test_perturbation_type',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=PERTURBATIONS.keys(),
                        help='Perturbation function used to transform test BabyLM dataset')
    parser.add_argument('train_set',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=["100M", "10M"],
                        help='BabyLM train set')
    parser.add_argument('random_seed', type=int, help="Random seed")
    parser.add_argument('paren_model',
                        default='all',
                        const='all',
                        nargs='?',
                        choices=list(PAREN_MODELS.keys()) + ["randinit"],
                        help='Parenthesis model')
    parser.add_argument('-np', '--no_pos_encodings', action='store_true',
                        help="Train GPT-2 with no positional encodings")

    args = parser.parse_args()
    no_pos_encodings_underscore = "_no_positional_encodings" if args.no_pos_encodings else ""

    # Checkpoint directory layout:
    #   {CHECKPOINT_READ_PATH}/<run family>/<run>/runs/<run>/checkpoint-<step>
    # (renamed from `model` so the string name and the loaded model object
    # below no longer share one variable)
    model_name = f"babylm_{args.perturbation_type}_{args.train_set}_{args.paren_model}{no_pos_encodings_underscore}_seed{args.random_seed}"
    model_path = f"{CHECKPOINT_READ_PATH}/babylm_{args.perturbation_type}_{args.train_set}_{args.paren_model}{no_pos_encodings_underscore}/{model_name}/runs/{model_name}/checkpoint-"

    # Affected test shards for the chosen test perturbation.
    test_files = sorted(glob(
        f"{BABYLM_DATA_PATH}/babylm_data_perturbed/babylm_{args.test_perturbation_type}/babylm_test_affected/*"))

    FILE_SAMPLE_SIZE = 1000
    rng = default_rng(args.random_seed)

    # ---- Sample token sequences from each affected test file --------------
    print("Sampling BabyLM affected test files to extract surprisals...")
    token_sequences = []
    for test_file in test_files:
        print(test_file)

        # Each line is a whitespace-separated sequence of token ids.
        # (fix: original opened the file without ever closing it)
        with open(test_file, 'r') as f:
            file_token_sequences = [
                [int(s) for s in l.split()] for l in f.readlines()]
        # Sample without replacement so each file contributes equally.
        sample_indices = rng.choice(
            list(range(len(file_token_sequences))), FILE_SAMPLE_SIZE, replace=False)
        file_token_sequences = [file_token_sequences[i]
                                for i in sample_indices]
        token_sequences.extend(file_token_sequences)

    # Decode sampled sequences back to text for the results table.
    test_sents = [gpt2_original_tokenizer.decode(
        toks) for toks in token_sequences]

    ppl_df = pd.DataFrame({
        "Sentences": test_sents
    })

    # ---- Score every checkpoint -------------------------------------------
    BATCH_SIZE = 8
    device = "cuda"
    for ckpt in CHECKPOINTS:
        print(f"Checkpoint: {ckpt}")

        # Load the checkpointed model (positional encodings optional).
        if args.no_pos_encodings:
            model = GPT2NoPositionalEncodingLMHeadModel.from_pretrained(
                model_path + str(ckpt)).to(device)
        else:
            model = GPT2LMHeadModel.from_pretrained(
                model_path + str(ckpt)).to(device)

        # Batched perplexity computation; EOS doubles as the pad token.
        perplexities = []
        for i in tqdm(range(0, len(token_sequences), BATCH_SIZE)):
            batch = token_sequences[i:i+BATCH_SIZE]
            ppls = get_perplexities(
                model, batch, gpt2_original_tokenizer.eos_token_id)
            perplexities.extend(ppls)

        # One column of perplexities per checkpoint.
        ppl_df[f'Perplexities (ckpt {ckpt})'] = perplexities

    # ---- Persist results ---------------------------------------------------
    directory = f"perplexity_results/{args.perturbation_type}_{args.train_set}{no_pos_encodings_underscore}"
    # exist_ok avoids the check-then-create race of the original
    # os.path.exists + os.makedirs pair.
    os.makedirs(directory, exist_ok=True)
    # renamed from `file` (shadowed the builtin)
    csv_path = directory + \
        f"/{args.paren_model}_seed{args.random_seed}_test_{args.test_perturbation_type}.csv"
    print(f"Writing results to CSV: {csv_path}")
    ppl_df.to_csv(csv_path)