"""CLP Inferencer."""

import itertools
import os
from typing import List, Optional

import torch.nn.functional as F
from tqdm import trange

from opencompass.models import BaseModel
from opencompass.registry import ICL_INFERENCERS

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever
from ..utils import get_logger
from .icl_base_inferencer import BaseInferencer, CLPInferencerOutputHandler

logger = get_logger(__name__)


@ICL_INFERENCERS.register_module()
class CLPInferencer(BaseInferencer):
    """Conditional log probability based In-context Learning Inferencer.

    Calculate the log probability of each choice according to the logits.
    The input is the context concatenated with a single choice, e.g.
    ``Q: xx.\n A: first choice of this question``. Starting from the first
    token of this choice, sum up the log probabilities of its tokens from the
    logits, then compare the choices with a softmax.

    There are two scenarios in this case:
    1. Single-token choices. Already supported.
    2. Multi-token choices. TODO: more complicated; needs to be added in the
       future for specific datasets.

    Attributes:
        model (:obj:`BaseModel`, optional): The module to inference.
        max_seq_len (:obj:`int`): Maximum number of tokenized words allowed by
            the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for the output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for the output
            `JSON` file.
        single_token (:obj:`bool`): If ``True``, choices only have one token
            to calculate. Defaults to ``True``; currently only ``True`` is
            supported.
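
    Example:
        A minimal usage sketch (assuming an already constructed ``model`` and
        ``retriever``; the variable names here are illustrative only)::

            inferencer = CLPInferencer(model=model, batch_size=8)
            predictions = inferencer.inference(retriever,
                                               ice_template=ice_template)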
    """

    def __init__(
            self,
            model: BaseModel,
            max_seq_len: Optional[int] = None,
            batch_size: Optional[int] = 1,
            output_json_filepath: Optional[str] = './icl_inference_output',
            output_json_filename: Optional[str] = 'predictions',
            single_token: bool = True,
            **kwargs) -> None:
        super().__init__(
            model=model,
            max_seq_len=max_seq_len,
            batch_size=batch_size,
            output_json_filename=output_json_filename,
            output_json_filepath=output_json_filepath,
            **kwargs,
        )
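
        # only single-token choices are handled; multi-token choices are a
        # known TODO (see the class docstring)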
        assert single_token, 'Only support single token choice currently.'
        self.single_token = single_token

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None,
                  normalizing_str: Optional[str] = None) -> List:
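        # 1. Preparation for output logs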
        output_handler = CLPInferencerOutputHandler()

        ice = []

        if output_json_filepath is None:
            output_json_filepath = self.output_json_filepath
        if output_json_filename is None:
            output_json_filename = self.output_json_filename
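
        # CLP needs raw logits, which API models usually do not expose, so
        # bail out early and record the error in the output file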
        if self.model.is_api:
            err_msg = ('API model is not supported for conditional log '
                       'probability inference and skip this exp.')
            if self.is_main_process:
                os.makedirs(output_json_filepath, exist_ok=True)
                output_handler.results_dict = {'error': err_msg}
                output_handler.write_to_json(output_json_filepath,
                                             output_json_filename)
            raise ValueError(err_msg)
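
        # 2. Get results of the retrieval process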
        ice_idx_list = retriever.retrieve()
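
        # 3. Generate in-context examples for the test inputs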
        for idx in range(len(ice_idx_list)):
            ice.append(
                retriever.generate_ice(ice_idx_list[idx],
                                       ice_template=ice_template))
        output_handler.save_ice(ice)
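
        # 4. Collect prompts and calculate conditional log probabilities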
        if self.single_token:
            index = 0
            prompt_list = []
            target_pos = []
            choices = retriever.test_ds[0]['choices']
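
            # tokenize each choice; fall back to the plain `encode` signature
            # and strip special tokens (BOS/EOS) the tokenizer may add, so
            # that every choice maps to a single vocabulary id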
            try:
                choice_ids = [
                    self.model.tokenizer.encode(c, False, False)
                    for c in choices
                ]
            except ValueError:
                choice_ids = [self.model.tokenizer.encode(c) for c in choices]
                if self.model.tokenizer.__class__.__name__ == 'ChatGLMTokenizer':  # noqa
                    choice_ids = [c[2:] for c in choice_ids]
                elif hasattr(self.model.tokenizer, 'add_bos_token'):
                    if self.model.tokenizer.add_bos_token:
                        choice_ids = [c[1:] for c in choice_ids]
                    if self.model.tokenizer.add_eos_token:
                        choice_ids = [c[:-1] for c in choice_ids]
            if isinstance(choice_ids[0], list):
                # flatten [[id], [id], ...] into [id, id, ...]
                choice_ids = list(itertools.chain(*choice_ids))

            get_token_len = self.model.get_token_len
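
            # the padding side decides where the last real prompt token ends
            # up in a padded batch, and therefore which position to read the
            # next-token logits from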
            if hasattr(self.model.tokenizer, 'padding_side'):
                padding_side = self.model.tokenizer.padding_side
            else:
                padding_side = 'left'
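
            # build the prompt for each example, dropping in-context examples
            # from the end until the prompt fits within max_seq_len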
            for idx in range(len(ice_idx_list)):
                prompt = retriever.generate_prompt_for_generate_task(
                    idx,
                    ice[idx],
                    ice_template=ice_template,
                    prompt_template=prompt_template)
                prompt = self.model.parse_template(prompt, mode='gen')
                prompt_token_num = get_token_len(prompt)
                if self.max_seq_len is not None:
                    while len(
                            ice_idx_list[idx]
                    ) > 0 and prompt_token_num + 1 > self.max_seq_len:
                        ice_idx_list[idx] = ice_idx_list[idx][:-1]
                        ice[idx] = retriever.generate_ice(
                            ice_idx_list[idx], ice_template=ice_template)
                        prompt = retriever.generate_prompt_for_generate_task(
                            idx,
                            ice[idx],
                            ice_template=ice_template,
                            prompt_template=prompt_template)
                        prompt_token_num = get_token_len(prompt)
                prompt_list.append(prompt)

                # clamp the token count in case the prompt still exceeds the
                # maximum length
                if self.max_seq_len is not None and \
                        prompt_token_num + 1 > self.max_seq_len:
                    prompt_token_num = self.max_seq_len - 1
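
                # record where the last prompt token sits after padding: with
                # left padding it is always the final position, with right
                # padding it is at (prompt length - 1)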
                if padding_side == 'left':
                    target_pos.append(-1)
                else:
                    target_pos.append(prompt_token_num - 1)
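
            # 4.1 Fetch gold answers if an output column exists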
            ds_reader = retriever.dataset_reader
            if ds_reader.output_column:
                gold_ans = ds_reader.dataset['test'][ds_reader.output_column]
            else:
                gold_ans = [None] * len(prompt_list)
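
            # batch padding is typically only exposed by HuggingFace-style
            # models; without it, prompts are scored one at a time below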
            if hasattr(self.model, 'batch_padding'):
                batch_padding = self.model.batch_padding
            else:
                batch_padding = False

            logger.info(
                'Calculating conditional log probability for prompts.')
            for idx in trange(0,
                              len(prompt_list),
                              self.batch_size,
                              disable=not self.is_main_process):
                sub_prompt_list = prompt_list[idx:idx + self.batch_size]
                sub_golds = gold_ans[idx:idx + self.batch_size]
                sub_target_pos = target_pos[idx:idx + self.batch_size]
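
                # compute per-choice probabilities, batched when padding is
                # available, otherwise one prompt at a time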
                if batch_padding and self.batch_size > 1:
                    sub_res = self._get_cond_prob(sub_prompt_list,
                                                  sub_target_pos, choice_ids)
                else:
                    sub_res = []
                    for prompt, position in zip(sub_prompt_list,
                                                sub_target_pos):
                        sub_res.extend(
                            self._get_cond_prob([prompt], [position],
                                                choice_ids))
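
                # save the prompt, the per-choice probabilities and the gold
                # answer for each example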
                for res, prompt, gold in zip(sub_res, sub_prompt_list,
                                             sub_golds):
                    # strip this example's own in-context examples (indexed
                    # by the global example counter, not the batch offset) to
                    # recover the bare test input
                    example_input = prompt.replace(ice[index], '')
                    output_handler.save_prompt_and_condprob(example_input,
                                                            prompt,
                                                            res,
                                                            index,
                                                            choices,
                                                            gold=gold)
                    index = index + 1
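
        # 5. Output the results to the JSON file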
        if self.is_main_process:
            os.makedirs(output_json_filepath, exist_ok=True)
            output_handler.write_to_json(output_json_filepath,
                                         output_json_filename)

        return [
            sample['prediction']
            for sample in output_handler.results_dict.values()
        ]

    def _get_cond_prob(self, input_texts: List[str], target_pos: List[int],
                       choice_ids: List[int]):
        """Get the conditional probabilities of the choice tokens.

        Args:
            input_texts (List[str]): All the input prompts to be tested.
            target_pos (List[int]): Position of the next-token logits for
                each prompt.
            choice_ids (List[int]): Token ids of the choices.
        """
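        # prefer the wrapped generator's get_logits when the model exposes one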
        if hasattr(self.model, 'generator'):
            get_logits = self.model.generator.get_logits
        else:
            get_logits = self.model.get_logits

        outputs, _ = get_logits(input_texts)
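
        # the logits at the target position already predict the next token,
        # so no shifting is needed before selecting the choice ids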
        logits = outputs.contiguous().float()
        logits = F.log_softmax(logits, dim=-1)

        log_probs = []
        for logit, target_ids in zip(logits, target_pos):
            # pick the log-probabilities of the choice tokens at the target
            # position and renormalize them over the choices
            log_probs.append(
                F.softmax(logit[target_ids, choice_ids], dim=-1).tolist())
        return log_probs