# -*- coding: utf-8 -*-
# @Time    : 2023/3/20 8:02 p.m.
# @Author  : Jianing Wang
# @File    : calibrate.py
import os
import numpy as np
import torch
"""
Use LM to classify label words for calibrating CLS
"""
class CLSCalibrator:
    pass
"""
Use Causal LM to generate label words for calibrating CLS
e.g., use gpt2 to generate a label word with in-context prompts, and calibrate for the prediction.
Paper: http://proceedings.mlr.press/v139/zhao21c.html
"""
class CausalCLSCalibrator:

    def __init__(self, model, tokenizer) -> None:
        self.model = model
        self.tokenizer = tokenizer
    def calibrate(self, all_label_probs, content_free_examples, label2id, mode="diagonal_W"):
        """Perform calibration for de-biasing and return the calibrated probabilities."""
        p_cf = self.get_content_free_prediction(content_free_examples, label2id)
        num_classes = all_label_probs.shape[1]
        if p_cf is None:
            # do not calibrate
            W = np.identity(num_classes)
            b = np.zeros([num_classes, 1])
        else:
            # calibrate
            if mode == "diagonal_W":
                W = np.linalg.inv(np.identity(num_classes) * p_cf)
                b = np.zeros([num_classes, 1])
            elif mode == "identity_W":
                W = np.identity(num_classes)
                b = -1 * np.expand_dims(p_cf, axis=-1)
            else:
                assert False

        all_calibrate_label_probs = list()
        for label_probs in all_label_probs:
            label_probs = label_probs / np.sum(label_probs)  # normalize to 1
            calibrate_label_probs = np.matmul(W, np.expand_dims(label_probs, axis=-1)) + b
            all_calibrate_label_probs.append(calibrate_label_probs.squeeze().tolist())
        return np.array(all_calibrate_label_probs)
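
    # Illustrative note (hypothetical numbers, not from the original code): with
    # p_cf = [0.7, 0.3] and mode="diagonal_W", calibrate() builds W = diag(1/0.7, 1/0.3)
    # and b = 0, so a raw prediction p = [0.6, 0.4] is rescaled to W @ p ≈ [0.857, 1.333];
    # the argmax flips from class 0 to class 1, counteracting the model's bias toward class 0.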
    def get_content_free_prediction(self, content_free_examples, label2id: dict):
        """Query the model with content-free input (e.g., "N/A") and return its prediction probability for each label."""
        all_p_y = []
        for content_free_example in content_free_examples:
            content_free_prompt = content_free_example["content_free_prompt"]
            p_y = [0] * len(label2id)
            for answers, i in label2id.items():
                prob = 0
                for a in answers:
                    response = self.get_causal_cls_prediction(content_free_prompt + " " + a, 0, echo=True, num_log_probs=1)
                    prob += np.exp(response['choices'][0]['logprobs']['token_logprobs'][-1])
                p_y[i] = prob
            all_p_y.append(p_y)

        p_y = np.mean(np.array(all_p_y), axis=0)
        p_y = p_y / np.sum(p_y)  # normalize
        return p_y
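
    # Illustrative note (hypothetical numbers, not from the original code): if the averaged,
    # normalized label-word probabilities on the content-free prompts come out as
    # p_cf = [0.8, 0.2], the model favors class 0 regardless of input content;
    # calibrate() above uses exactly this vector to construct W and b.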
    def get_causal_cls_prediction(self, prompt, l=10, num_log_probs=None, echo=False):
        '''This function runs GPT-2 locally but places the outputs into a JSON object that looks
        just like the one provided by the OpenAI API.'''
        if isinstance(prompt, str):
            prompt = [prompt]  # the code below assumes a list
        input_ids = self.tokenizer.batch_encode_plus(prompt, return_tensors="pt", padding=True)

        # truncate from the left so that the prompt plus l generated tokens fits in GPT-2's 1024-token context
        if l + len(input_ids['input_ids'][0]) > 1020:
            m = l + len(input_ids['input_ids'][0]) - 1024
            input_ids['input_ids'] = torch.Tensor([input_ids['input_ids'][0][m:].numpy()]).long()
            input_ids['attention_mask'] = torch.Tensor([input_ids['attention_mask'][0][m:].numpy()]).long()
        # greedily generate l tokens
        # print("l=", l)
        if l > 0:
            # the generate function can handle left padded inputs automatically in HF
            # total_sequences is now the input + possible generated output
            # print("l + len(input_ids['input_ids'][0])=", l + len(input_ids['input_ids'][0]))
            total_sequences = self.model.generate(
                input_ids=input_ids['input_ids'].to(self.model.device),
                attention_mask=input_ids['attention_mask'].to(self.model.device),
                max_length=l + len(input_ids['input_ids'][0]),
                do_sample=False
            )
        else:
            assert echo == True and l == 0
            total_sequences = input_ids['input_ids'].to(self.model.device)
        # print("="*50)
        # print("total_sequences=", total_sequences)  # shape: [batch, len+l]
        # print("total_sequences.shape=", total_sequences.shape)
        # they want the probs of the top tokens
        if num_log_probs is not None:
            # we are left padding, so we need to adjust the position IDs
            # (50256 is GPT-2's EOS id, used here as the pad token)
            attention_mask = (total_sequences != 50256).float()
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            # get the logits for the context and the next l tokens
            logits = self.model.forward(input_ids=total_sequences, attention_mask=attention_mask, position_ids=position_ids, return_dict=True).logits.detach().cpu()
            if not echo:
                # get the top tokens and probs for the generated l tokens
                probs = torch.softmax(logits[:, -l - 1:], dim=2).cpu()
            else:
                # get the top tokens and probs for the context and the generated l tokens
                probs = torch.softmax(logits, dim=2).cpu()
            top_probs, top_tokens = torch.topk(probs, k=num_log_probs)
            logprobs = torch.log(probs)
            top_log_probs = torch.log(top_probs)
            # print("top_log_probs=", top_log_probs)
            # print("top_log_probs.shape=", top_log_probs.shape)  # e.g. [1, 2, 100] = [batch, seq_len, api_num_log_prob]

        # create the return value to resemble OpenAI
        return_json = {}
        choices = []
        # print("="*50)
        for batch_id in range(len(prompt)):
            curr_json = {}
            # text is just the optional context and next l tokens
            if not echo:
                curr_json['text'] = self.tokenizer.decode(total_sequences[batch_id][-l:], skip_special_tokens=True)
            else:
                curr_json['text'] = self.tokenizer.decode(total_sequences[batch_id], skip_special_tokens=True)

            # fill the return json with the top tokens and probs to match the OpenAI return value.
            if num_log_probs is not None:
                curr_json['logprobs'] = {}
                curr_json['logprobs']['top_logprobs'] = []
                curr_json['logprobs']['token_logprobs'] = []
                curr_json['logprobs']['tokens'] = []
                if not echo:
                    # cut off the -1 here because the probs are shifted one over for LMs
                    for current_element_top_log_probs, current_element_top_tokens in zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1]):
                        # tokens is a list of the top token at each position
                        curr_json['logprobs']['tokens'].append(self.tokenizer.decode([current_element_top_tokens[0]]))
                        # token_logprobs is a list of the logprob of the top token at each position
                        curr_json['logprobs']['token_logprobs'].append(current_element_top_log_probs[0].item())
                        # top_logprobs is a list of dicts for the top K tokens, with each entry being {'token_name': log_prob}
                        temp = {}
                        for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
                            temp[self.tokenizer.decode(token.item())] = log_prob.item()
                        curr_json['logprobs']['top_logprobs'].append(temp)
                else:
                    # same as above but with small tweaks
                    # we add null to the front because the GPT models have null probability for the first token
                    # (for some reason they don't have a beginning-of-sentence token)
                    curr_json['logprobs']['top_logprobs'].append('null')
                    # cut off the -1 here because the probs are shifted one over for LMs
                    for index, (current_element_top_log_probs, current_element_top_tokens) in enumerate(zip(top_log_probs[batch_id][:-1], top_tokens[batch_id][:-1])):
                        # skip padding tokens
                        if total_sequences[batch_id][index].item() == 50256:
                            continue
                        temp = {}
                        for log_prob, token in zip(current_element_top_log_probs, current_element_top_tokens):
                            temp[self.tokenizer.decode(token.item())] = log_prob.item()
                        curr_json['logprobs']['top_logprobs'].append(temp)
                    for index in range(len(probs[batch_id])):
                        curr_json['logprobs']['tokens'].append(self.tokenizer.decode([total_sequences[batch_id][index]]))
                    curr_json['logprobs']['token_logprobs'].append('null')
                    for index, log_probs_token_position_j in enumerate(logprobs[batch_id][:-1]):
                        # probs are left shifted for LMs
                        curr_json['logprobs']['token_logprobs'].append(log_probs_token_position_j[total_sequences[batch_id][index + 1]])
            choices.append(curr_json)
            # print("curr_json=", curr_json)
            '''
            e.g.,
            num_tokens_to_predict=1
            curr_json = {
                'text': ' I',  # the top token generated at the current step
                'logprobs': {
                    'top_logprobs': [{' I': -3.4267239570617676, '\n': -3.5073862075805664, ...}],  # top-100 tokens and their scores
                    'token_logprobs': [-3.4267239570617676],  # score of the current top token
                    'tokens': [' I']
                }
            }
            num_tokens_to_predict=2
            curr_json = {
                'text': '\nThe',  # two tokens if two tokens are requested
                'logprobs': {
                    'top_logprobs': [  # predicted scores at each of the two positions
                        {'\n': -3.186706304550171, '\xa0': -3.222092390060425, ' We': -6.781067848205566, ...},
                        {'The': -2.5251243114471436, '"': -2.857935667037964, ...}
                    ],
                    'token_logprobs': [-3.186706304550171, -2.5251243114471436],  # scores of the generated tokens
                    'tokens': ['\n', 'The']
                }
            }
            '''
        return_json['choices'] = choices
        # print("="*50)
        # print("return_json=", return_json)
        return return_json
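

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original pipeline): it
# assumes a GPT-2 checkpoint loaded through Hugging Face transformers and a toy
# binary sentiment task; `all_label_probs`, the label words, and the
# content-free prompt below are hypothetical placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.eval()

    calibrator = CausalCLSCalibrator(model, tokenizer)

    # Uncalibrated label probabilities: one row per example, one column per class.
    all_label_probs = np.array([[0.7, 0.3],
                                [0.6, 0.4]])
    # Keys are tuples of label words, values are class ids (as expected by get_content_free_prediction).
    label2id = {(" positive",): 0, (" negative",): 1}
    # Content-free input in the spirit of Zhao et al. (2021), e.g. "N/A".
    content_free_examples = [{"content_free_prompt": "Review: N/A\nSentiment:"}]

    calibrated_scores = calibrator.calibrate(all_label_probs, content_free_examples, label2id)
    print(calibrated_scores)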