"""Basic Inferencer."""
import json
import os
from pathlib import Path
from typing import List, Optional

import numpy as np
from mmengine.dist import is_main_process
from torch.utils.data import DataLoader

from ..icl_prompt_template import PromptTemplate
from ..icl_retriever import BaseRetriever


class BaseInferencer:
    """Base Inferencer class for all evaluation Inferencers.

    Attributes:
        model (:obj:`BaseModel`, optional): The model to run inference with.
        max_seq_len (:obj:`int`, optional): Maximum number of tokenized words
            allowed by the LM.
        batch_size (:obj:`int`, optional): Batch size for the
            :obj:`DataLoader`.
        output_json_filepath (:obj:`str`, optional): File path for the output
            `JSON` file.
        output_json_filename (:obj:`str`, optional): File name for the output
            `JSON` file.
    """
    model = None

    def __init__(
        self,
        model,
        max_seq_len: Optional[int] = None,
        batch_size: Optional[int] = 1,
        output_json_filepath: Optional[str] = './icl_inference_output',
        output_json_filename: Optional[str] = 'predictions',
        fix_id_list: Optional[List[int]] = None,
        **kwargs,
    ) -> None:

        if fix_id_list:
            raise ValueError('Passing fix_id_list to Inferencer is no longer '
                             'allowed. Please pass it to FixKRetriever '
                             'instead.')

        self.model = model

        self.max_seq_len = max_seq_len
        self.batch_size = batch_size
        self.output_json_filepath = output_json_filepath
        self.output_json_filename = output_json_filename
        self.is_main_process = is_main_process()
        os.makedirs(self.output_json_filepath, exist_ok=True)

    def inference(self,
                  retriever: BaseRetriever,
                  ice_template: Optional[PromptTemplate] = None,
                  prompt_template: Optional[PromptTemplate] = None,
                  output_json_filepath: Optional[str] = None,
                  output_json_filename: Optional[str] = None) -> List:
        """Perform In-Context Inference given a retriever and optional
        templates.

        Args:
            retriever (:obj:`BaseRetriever`): An instance of a Retriever class
                that will be used to retrieve in-context examples.
            ice_template (:obj:`PromptTemplate`, optional): A template for
                generating the in-context examples prompt. Defaults to None.
            prompt_template (:obj:`PromptTemplate`, optional): A template for
                generating the final prompt. Defaults to None.
            output_json_filepath (:obj:`str`, optional): The file path to save
                the results as a `JSON` file. Defaults to None.
            output_json_filename (:obj:`str`, optional): The file name to save
                the results as a `JSON` file. Defaults to None.

        Raises:
            NotImplementedError: If the method is not implemented in the
                subclass.

        Returns:
            :obj:`List`: A list of strings, each representing the result of
                one inference.
        """
        raise NotImplementedError("Method hasn't been implemented yet")

    @staticmethod
    def get_dataloader(datalist: List[List], batch_size: int) -> DataLoader:
        """Return a dataloader of the input data list."""
        dataloader = DataLoader(datalist,
                                batch_size=batch_size,
                                collate_fn=lambda x: x)
        return dataloader


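# A concrete inferencer overrides ``inference``. The sketch below is purely
# illustrative: ``MyGenInferencer`` is a hypothetical subclass and the
# retriever helper calls are assumptions, not guarantees made by this module.
#
#     class MyGenInferencer(BaseInferencer):
#         def inference(self, retriever, ice_template=None,
#                       prompt_template=None, output_json_filepath=None,
#                       output_json_filename=None):
#             handler = GenInferencerOutputHandler()
#             ice_idx_list = retriever.retrieve()  # in-context example ids
#             prompts = [...]  # build one prompt per test sample
#             dataloader = self.get_dataloader(prompts, self.batch_size)
#             for batch in dataloader:
#                 ...  # run self.model on the batch, record via handler
#             if self.is_main_process:
#                 handler.write_to_json(
#                     output_json_filepath or self.output_json_filepath,
#                     output_json_filename or self.output_json_filename)
#             return [v['prediction'] for v in handler.results_dict.values()]

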
def dump_results_dict(results_dict, filename):
    """Dump ``results_dict`` to ``filename`` as a UTF-8 encoded JSON file."""
    with open(filename, 'w', encoding='utf-8') as json_file:
        json.dump(results_dict, json_file, indent=4, ensure_ascii=False)


class GenInferencerOutputHandler:
    """Output handler for generation-based inferencers."""
    origin_prompt_dict = {}
    output_dict = {}
    prediction_dict = {}
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the results to a json file."""
        dump_results_dict(self.results_dict, Path(save_dir) / filename)

    def save_results(self, origin_prompt, prediction, idx, gold=None):
        """Record the prompt, prediction and optional gold answer for idx."""
        self.results_dict[str(idx)] = {
            'origin_prompt': origin_prompt,
            'prediction': prediction,
        }
        if gold is not None:
            self.results_dict[str(idx)]['gold'] = gold


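# Usage sketch for GenInferencerOutputHandler (paths and values are
# illustrative only):
#
#     handler = GenInferencerOutputHandler()
#     handler.save_results('Q: 1+1=\nA:', '2', 0, gold='2')
#     handler.write_to_json('./icl_inference_output', 'predictions.json')
#
# which writes a file shaped like:
#
#     {"0": {"origin_prompt": "Q: 1+1=\nA:", "prediction": "2", "gold": "2"}}

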
class PPLInferencerOutputHandler:
    """Output handler for perplexity (PPL) based inferencers."""
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the results to a json file."""
        dump_results_dict(self.results_dict, Path(save_dir) / filename)

    def save_ice(self, ice):
        """Record the in-context examples for each test index."""
        for idx, example in enumerate(ice):
            if str(idx) not in self.results_dict.keys():
                self.results_dict[str(idx)] = {}
            self.results_dict[str(idx)]['in-context examples'] = example

    def save_predictions(self, predictions):
        """Record the final prediction for each test index."""
        for idx, prediction in enumerate(predictions):
            if str(idx) not in self.results_dict.keys():
                self.results_dict[str(idx)] = {}
            self.results_dict[str(idx)]['prediction'] = prediction

    def save_prompt_and_ppl(self, label, input, prompt, ppl, idx):
        """Record the prompt and its perplexity for ``label`` at ``idx``."""
        if str(idx) not in self.results_dict.keys():
            self.results_dict[str(idx)] = {}
        if 'label: ' + str(label) not in self.results_dict[str(idx)].keys():
            self.results_dict[str(idx)]['label: ' + str(label)] = {}
        self.results_dict[str(idx)]['label: ' +
                                    str(label)]['testing input'] = input
        self.results_dict[str(idx)]['label: ' + str(label)]['prompt'] = prompt
        self.results_dict[str(idx)]['label: ' + str(label)]['PPL'] = ppl

    def save_golds(self, golds):
        """Record the gold answer for each test index."""
        for idx, gold in enumerate(golds):
            if str(idx) not in self.results_dict.keys():
                self.results_dict[str(idx)] = {}
            self.results_dict[str(idx)]['gold'] = gold


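# Usage sketch for PPLInferencerOutputHandler (values are illustrative; the
# final label selection shown is an assumption about how a PPL-style
# inferencer typically uses the recorded perplexities):
#
#     handler = PPLInferencerOutputHandler()
#     handler.save_ice(['Q: 2+2=\nA: 4\n'])
#     handler.save_prompt_and_ppl('A', 'Q: 1+1=',
#                                 'Q: 2+2=\nA: 4\nQ: 1+1=\nA: A', 3.2, 0)
#     handler.save_prompt_and_ppl('B', 'Q: 1+1=',
#                                 'Q: 2+2=\nA: 4\nQ: 1+1=\nA: B', 1.7, 0)
#     # A PPL inferencer would normally predict the label whose prompt has
#     # the lowest perplexity ('B' here) and record it via save_predictions.
#     handler.save_predictions(['B'])
#     handler.save_golds(['B'])

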
class CLPInferencerOutputHandler:
    """Output handler for conditional-probability (CLP) based inferencers."""
    results_dict = {}

    def __init__(self) -> None:
        self.results_dict = {}

    def write_to_json(self, save_dir: str, filename: str):
        """Dump the results to a json file."""
        dump_results_dict(self.results_dict, Path(save_dir) / filename)

    def save_ice(self, ice):
        """Record the in-context examples for each test index."""
        for idx, example in enumerate(ice):
            if str(idx) not in self.results_dict.keys():
                self.results_dict[str(idx)] = {}
            self.results_dict[str(idx)]['in-context examples'] = example

    def save_prompt_and_condprob(self,
                                 input,
                                 prompt,
                                 cond_prob,
                                 idx,
                                 choices,
                                 gold=None):
        """Record the prompt, per-choice probabilities and predicted label."""
        if str(idx) not in self.results_dict.keys():
            self.results_dict[str(idx)] = {}

        self.results_dict[str(idx)]['testing input'] = input
        self.results_dict[str(idx)]['prompt'] = prompt
        self.results_dict[str(idx)]['choices'] = choices
        self.results_dict[str(idx)]['prediction'] = cond_prob
        # The predicted label is the index of the choice with the highest
        # conditional probability.
        self.results_dict[str(idx)]['pred_label'] = int(np.argmax(cond_prob))
        self.results_dict[str(idx)]['gold'] = gold
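

# Usage sketch for CLPInferencerOutputHandler (values are illustrative only):
#
#     handler = CLPInferencerOutputHandler()
#     handler.save_ice(['Q: 2+2=\nA: 4\n'])
#     handler.save_prompt_and_condprob('Q: 1+1=',
#                                      'Q: 2+2=\nA: 4\nQ: 1+1=\nA:',
#                                      [0.1, 0.8, 0.1], 0,
#                                      choices=['1', '2', '3'], gold='2')
#     # results_dict['0']['pred_label'] is 1, i.e. the index of choice '2',
#     # because np.argmax picks the highest conditional probability.
#     handler.write_to_json('./icl_inference_output', 'predictions.json')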