import os
import sys
import copy
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from . import utils


def perplexity_from_logits(am_gen_logits, om_gen_logits):
    """ Calculate perplexity of one model's logits against reference tokens
    (or against a second model's logits, reduced to greedy tokens)
    """
    # reduce full logits to greedy token ids if a logit tensor was passed
    if len(om_gen_logits.squeeze().shape) > 1:
        om_gen_logits = torch.argmax(om_gen_logits.squeeze(), dim=-1)
    # gather the log-probability assigned to each reference token
    m = nn.LogSoftmax(dim=1)
    log_probs = torch.gather(
        m(am_gen_logits.float()), 1, om_gen_logits[:, None])[:, 0]
    # perplexity = exp(mean negative log-likelihood) over the sequence
    return torch.exp(-1 / om_gen_logits.size(0) * log_probs.sum()).item()
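
# The returned value is exp(-(1/T) * sum_t log p(token_t)): the exponentiated
# mean negative log-likelihood of the T reference tokens. A quick sanity check
# (illustrative shapes only): uniform logits over a vocabulary of size V give
# every token probability 1/V, so the perplexity is exactly V:
#   perplexity_from_logits(torch.zeros(20, 100), torch.arange(20))  # -> 100.0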


def set_perplexity_from_logits(am_set, om_set, prompt_lens):
    """ Calculate perplexity from two sets of logits (for a set of samples),
    ignoring the prompt tokens at the start of each sequence
    """
    perplexities = np.zeros(len(om_set))
    for i in range(len(om_set)):
        perplexities[i] = perplexity_from_logits(
            am_set[i][prompt_lens[i]:],
            om_set[i][prompt_lens[i]:]
        )
    return perplexities


def generation_ppl(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    prompts: List[str],
    tokens_true: Optional[torch.Tensor] = None,
    token_window: int = 30,
    batch_size: int = 32,
    verbose: bool = False
):
    """ Run generation and calculate perplexity of each continuation
    """
    from . import generate
    texts = []
    preds = []
    perplexity = []
    # duplicate a lone prompt so the squeeze below keeps a batch dimension
    if len(prompts) == 1:
        prompts = prompts * 2
    prompt_lens = [len(tok.encode(p)) for p in prompts]
    # drop prompts that leave no room for generated tokens
    prompt_mask = np.array(prompt_lens) < (token_window - 1)
    if np.sum(prompt_mask) != len(prompts):
        print('Removed prompts longer than the token window')
        prompts = list(np.array(prompts)[prompt_mask])
        prompt_lens = list(np.array(prompt_lens)[prompt_mask])
    # find number of batches (after filtering, so no batch is empty)
    num_batches = int(np.ceil(len(prompts) / batch_size))
    for i in tqdm(range(num_batches), disable=(not verbose)):
        # run generation for this batch of prompts
        gen_texts, gen_logits = generate.generate_fast(
            model,
            tok,
            prompts=prompts[i * batch_size:(i + 1) * batch_size],
            n_gen_per_prompt=1,
            top_k=1,
            max_out_len=token_window,
            return_logits=True,
        )
        pred_tokens = torch.argmax(gen_logits.squeeze(), dim=-1)
        # get true tokens (default to the model's own greedy predictions)
        if tokens_true is None:
            subset_tokens_true = pred_tokens
        else:
            subset_tokens_true = tokens_true[i * batch_size:(i + 1) * batch_size]
        if isinstance(subset_tokens_true, np.ndarray):
            subset_tokens_true = torch.from_numpy(subset_tokens_true)
        # calculate perplexity over the generated (non-prompt) tokens
        ppl = set_perplexity_from_logits(
            gen_logits, subset_tokens_true,
            prompt_lens[i * batch_size:(i + 1) * batch_size])
        texts = texts + gen_texts
        preds.append(pred_tokens.cpu().numpy())
        perplexity.append(ppl)
    texts = np.array(texts)
    preds = np.concatenate(preds)
    perplexity = np.concatenate(perplexity)
    return texts, preds, perplexity
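
# Usage sketch (hypothetical model/tokenizer choice, not part of this module):
#   model = AutoModelForCausalLM.from_pretrained('gpt2')
#   tok = AutoTokenizer.from_pretrained('gpt2')
#   texts, preds, ppl = generation_ppl(
#       model, tok,
#       prompts=['The Eiffel Tower is located in',
#                'The capital of France is'],
#       token_window=30, batch_size=32)
# `texts` holds the generated strings, `preds` the greedy token ids, and
# `ppl` one perplexity value per prompt that fits inside the token window.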


def cache_ppl(
    model,
    tok,
    dataset,
    cache_ppl_file,
    token_window=50,
    batch_size=64,
    static_context='',
    selection=None,
    reverse_selection=False,
    verbose=True
):
    """ Load cached perplexity measures, computing and caching them first
    if no cache file exists
    """
    if os.path.exists(cache_ppl_file):
        print('Loaded cached perplexity file: ', cache_ppl_file)
        cache_ppl_contents = utils.loadpickle(cache_ppl_file)
        raw_case_ids = cache_ppl_contents['case_ids']
    else:
        # find raw requests and case_ids
        raw_ds, _, _ = utils.load_dataset(tok, ds_name=dataset)
        raw_requests = utils.extract_requests(raw_ds)
        raw_case_ids = np.array([r['case_id'] for r in raw_requests])
        print('Running perplexity evaluation for original model and prompts...')
        texts, preds, ppl_values = generation_ppl(
            model,
            tok,
            prompts=[static_context + r['prompt'].format(r['subject'])
                     for r in raw_requests],
            tokens_true=None,
            token_window=token_window,
            batch_size=batch_size,
            verbose=verbose
        )
        cache_ppl_contents = {
            'texts': texts,
            'preds': preds,
            'requests': raw_requests,
            'perplexity': ppl_values,
            'case_ids': raw_case_ids,
            'token_window': token_window,
            'batch_size': batch_size,
            'static_context': static_context
        }
        utils.assure_path_exists(os.path.dirname(cache_ppl_file))
        utils.savepickle(cache_ppl_file, cache_ppl_contents)
        print('Saved perplexity cache file: ', cache_ppl_file)
    # filter cache_ppl_contents for selected samples
    if selection is not None:
        # the selection json holds a dict whose 'case_ids' key lists the
        # selected samples
        select_case_ids = utils.loadjson(selection)['case_ids']
        # boolean mask for selected samples w.r.t. all samples in the cache
        matching = utils.generate_mask(raw_case_ids, np.array(select_case_ids))
        if reverse_selection:
            matching = ~matching
        cache_ppl_contents = utils.filter_for_selection(cache_ppl_contents, matching)
    return cache_ppl_contents
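

if __name__ == '__main__':
    # Minimal smoke test (a sketch; run via `python -m <package>.<module>`,
    # since this file uses relative imports). With uniform logits every token
    # has probability 1/V, so each perplexity should equal the vocab size V.
    V, T = 100, 20
    logit_set = [torch.zeros(T, V) for _ in range(3)]            # 3 samples
    token_set = [torch.randint(0, V, (T,)) for _ in range(3)]    # reference ids
    ppl = set_perplexity_from_logits(logit_set, token_set, prompt_lens=[2, 4, 6])
    print('uniform-logits perplexities:', ppl)   # each ~100.0
    assert np.allclose(ppl, V)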