# stealth-edits/util/perplexity.py

import os
import sys
import copy
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from . import utils

def perplexity_from_logits(am_gen_logits, om_gen_logits):
    """ Calculate perplexity of one set of logits against reference tokens.

    am_gen_logits: logits of shape (seq_len, vocab_size)
    om_gen_logits: reference token ids of shape (seq_len,), or reference
        logits of shape (seq_len, vocab_size) reduced to their argmax tokens
    """
    # reduce reference logits to greedy token ids if necessary
    if len(om_gen_logits.squeeze().shape) > 1:
        om_gen_logits = torch.argmax(om_gen_logits.squeeze(), dim=-1)

    # per-position log-probabilities of the reference tokens
    m = nn.LogSoftmax(dim=1)
    log_probs = torch.gather(
        m(am_gen_logits.float()), 1, om_gen_logits[:, None])[:, 0]

    # perplexity = exp( -(1/N) * sum_i log p(token_i) )
    return torch.exp(-1 / om_gen_logits.size(0) * log_probs.sum()).item()
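
# Worked example of the formula above (illustrative, not part of the original
# module): for per-position probabilities p(t_1), ..., p(t_N) of the reference
# tokens,
#
#   ppl = exp( -(1/N) * sum_i log p(t_i) )
#
# so a model that assigns probability 1 to every reference token has ppl = 1,
# while a uniform model over a vocabulary of size V has ppl = V.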

def set_perplexity_from_logits(am_set, om_set, prompt_lens):
    """ Calculate perplexity for a set of samples, scoring only the tokens
    after each sample's prompt.
    """
    perplexities = np.zeros(len(om_set))
    for i in range(len(om_set)):
        # skip the first prompt_lens[i] positions so the prompt is not scored
        perplexities[i] = perplexity_from_logits(
            am_set[i][prompt_lens[i]:],
            om_set[i][prompt_lens[i]:]
        )
    return perplexities
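
# For example (illustrative values): with generations of length token_window = 30
# and prompt_lens = [4, 7], sample 0 is scored on positions 4..29 and sample 1
# on positions 7..29 of its logits, mirroring how generation_ppl (below) calls
# this function.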

def generation_ppl(
    model: AutoModelForCausalLM,
    tok: AutoTokenizer,
    prompts: List[str],
    tokens_true: Optional[torch.Tensor] = None,
    token_window: int = 30,
    batch_size: int = 32,
    verbose: bool = False
):
    """ Run greedy generation for a list of prompts and calculate the
    perplexity of each generated continuation.
    """
    from . import generate

    texts = []
    preds = []
    perplexity = []

    # duplicate a single prompt so generation and scoring always see a batch
    if len(prompts) == 1: prompts = prompts * 2

    # drop prompts that are too long to fit generation into the token window
    prompt_lens = [
        len(tok.encode(p)) for p in prompts
    ]
    prompt_mask = np.array(prompt_lens) < (token_window - 1)
    if np.sum(prompt_mask) != len(prompts):
        print('Removed prompts that are too long for the token window')
        prompts = list(np.array(prompts)[prompt_mask])
        prompt_lens = list(np.array(prompt_lens)[prompt_mask])

    # find number of batches (after filtering)
    num_batches = int(np.ceil(len(prompts) / batch_size))

    for i in tqdm(range(num_batches), disable=(not verbose)):

        # run generation
        gen_texts, gen_logits = generate.generate_fast(
            model,
            tok,
            prompts = prompts[i*batch_size:(i+1)*batch_size],
            n_gen_per_prompt = 1,
            top_k = 1,
            max_out_len = token_window,
            return_logits = True,
        )
        pred_tokens = torch.argmax(gen_logits.squeeze(), dim=-1)

        # get true tokens
        if tokens_true is None:
            subset_tokens_true = pred_tokens
        else:
            subset_tokens_true = tokens_true[i*batch_size:(i+1)*batch_size]
        if isinstance(subset_tokens_true, np.ndarray):
            subset_tokens_true = torch.from_numpy(subset_tokens_true)

        # calculate perplexity
        ppl = set_perplexity_from_logits(
            gen_logits, subset_tokens_true, prompt_lens[i*batch_size:(i+1)*batch_size])

        texts = texts + gen_texts
        preds.append(pred_tokens.numpy())
        perplexity.append(ppl)

    texts = np.array(texts)
    preds = np.concatenate(preds)
    perplexity = np.concatenate(perplexity)
    return texts, preds, perplexity
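
# A usage sketch (the model name and prompt below are illustrative, not part of
# this module):
#
#   model = AutoModelForCausalLM.from_pretrained('gpt2')
#   tok = AutoTokenizer.from_pretrained('gpt2')
#   texts, preds, ppl = generation_ppl(
#       model, tok,
#       prompts=['The Eiffel Tower is located in the city of'],
#       token_window=30,
#   )
#
# With tokens_true=None each prompt is scored against the model's own greedy
# tokens; passing reference tokens instead measures perplexity against those.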

def cache_ppl(
    model,
    tok,
    dataset,
    cache_ppl_file,
    token_window = 50,
    batch_size = 64,
    static_context = '',
    selection = None,
    reverse_selection = False,
    verbose = True
):
    """ Load cached perplexity measures for a dataset, or compute and cache
    them for the given model if no cache file exists yet.
    """
    if os.path.exists(cache_ppl_file):
        print('Loaded cached perplexity file: ', cache_ppl_file)
        cache_ppl_contents = utils.loadpickle(cache_ppl_file)
        raw_case_ids = cache_ppl_contents['case_ids']
    else:
        # find raw requests and case_ids
        raw_ds, _, _ = utils.load_dataset(tok, ds_name=dataset)
        raw_requests = utils.extract_requests(raw_ds)
        raw_case_ids = np.array([r['case_id'] for r in raw_requests])

        print('Running perplexity evaluation for original model and prompts...')
        texts, preds, ppl_values = generation_ppl(
            model,
            tok,
            prompts = [static_context + r['prompt'].format(r['subject']) for r in raw_requests],
            tokens_true = None,
            token_window = token_window,
            batch_size = batch_size,
            verbose = verbose
        )
        cache_ppl_contents = {
            'texts': texts,
            'preds': preds,
            'requests': raw_requests,
            'perplexity': ppl_values,
            'case_ids': raw_case_ids,
            'token_window': token_window,
            'batch_size': batch_size,
            'static_context': static_context
        }
        utils.assure_path_exists(os.path.dirname(cache_ppl_file))
        utils.savepickle(cache_ppl_file, cache_ppl_contents)
        print('Saved perplexity cache file: ', cache_ppl_file)

    # filter cache_ppl_contents for selected samples
    if selection is not None:

        # load json file whose 'case_ids' entry lists the selected samples
        select_case_ids = utils.loadjson(selection)['case_ids']

        # boolean mask marking the selected samples among all cached case_ids
        matching = utils.generate_mask(raw_case_ids, np.array(select_case_ids))
        if reverse_selection: matching = ~matching

        # keep only the selected samples in the cached contents
        cache_ppl_contents = utils.filter_for_selection(cache_ppl_contents, matching)

    return cache_ppl_contents
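

if __name__ == '__main__':
    # Minimal self-check sketch (not part of the original module): build random
    # logits and compare perplexity against the greedy tokens vs. random tokens.
    # The greedy tokens are the most probable at every position, so the first
    # value should be the smaller of the two. Run from the repository root with
    # `python -m util.perplexity` (assumed package layout) so the relative
    # imports above resolve.
    torch.manual_seed(0)
    seq_len, vocab_size = 20, 100
    logits = torch.randn(seq_len, vocab_size)

    greedy_tokens = torch.argmax(logits, dim=-1)
    random_tokens = torch.randint(0, vocab_size, (seq_len,))

    print('ppl vs greedy tokens:', perplexity_from_logits(logits, greedy_tokens))
    print('ppl vs random tokens:', perplexity_from_logits(logits, random_tokens))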