"""Self-Consistency Generation Inferencer.""" import os import os.path as osp from typing import List, Optional import mmengine import torch from tqdm import tqdm from opencompass.models.base import BaseModel from ..icl_prompt_template import PromptTemplate from ..icl_retriever import BaseRetriever from ..utils.logging import get_logger from .icl_base_inferencer import BaseInferencer, GenInferencerOutputHandler logger = get_logger(__name__) class SCInferencer(BaseInferencer): """Self-Consistency Inferencer class to evaluate by multiple generations. Attributes: model (:obj:`BaseModelWrapper`, optional): The module to inference. max_seq_len (:obj:`int`, optional): Maximum number of tokenized words allowed by the LM. batch_size (:obj:`int`, optional): Batch size for the :obj:`DataLoader`. output_json_filepath (:obj:`str`, optional): File path for output `JSON` file. output_json_filename (:obj:`str`, optional): File name for output `JSON` file. gen_field_replace_token (:obj:`str`, optional): Used to replace the generation field token when generating prompts. save_every (:obj:`int`, optional): Save intermediate results every `save_every` iters. Defaults to 1. generation_kwargs (:obj:`Dict`, optional): Parameters for the :obj:`model.generate()` method. sc_size (:obj:`int`, optional): Sample size for Self-Consistency infer_type (:obj:`str`, optional): Infer CoT type for :obj:`inference()` method. """ def __init__( self, model: BaseModel, max_out_len: int, max_seq_len: Optional[int] = None, batch_size: Optional[int] = 1, gen_field_replace_token: Optional[str] = '', output_json_filepath: Optional[str] = './icl_inference_output', output_json_filename: Optional[str] = 'predictions', save_every: Optional[int] = 1, sc_size: Optional[int] = 1, infer_type: Optional[str] = '', generation_kwargs: dict = {}, **kwargs) -> None: super().__init__( model=model, max_seq_len=max_seq_len, batch_size=batch_size, output_json_filename=output_json_filename, output_json_filepath=output_json_filepath, **kwargs, ) self.gen_field_replace_token = gen_field_replace_token self.generation_kwargs = generation_kwargs self.max_out_len = max_out_len self.sc_size = sc_size if self.model.is_api and save_every is None: save_every = 1 self.save_every = save_every def inference(self, retriever: BaseRetriever, ice_template: Optional[PromptTemplate] = None, prompt_template: Optional[PromptTemplate] = None, output_json_filepath: Optional[str] = None, output_json_filename: Optional[str] = None) -> List: # 1. Preparation for output logs output_handler = GenInferencerOutputHandler() if output_json_filepath is None: output_json_filepath = self.output_json_filepath if output_json_filename is None: output_json_filename = self.output_json_filename # 2. Get results of retrieval process ice_idx_list = retriever.retrieve() # 3. Generate prompts for testing input prompt_list = self.get_generation_prompt_list_from_retriever_indices( ice_idx_list, retriever, self.gen_field_replace_token, max_seq_len=self.max_seq_len, ice_template=ice_template, prompt_template=prompt_template) # 3.1 Fetch and zip prompt & gold answer if output column exists ds_reader = retriever.dataset_reader if ds_reader.output_column: gold_ans = ds_reader.dataset['test'][ds_reader.output_column] prompt_list = list(zip(prompt_list, gold_ans)) # Create tmp json file for saving intermediate results and future # resuming index = 0 tmp_json_filepath = os.path.join(output_json_filepath, 'tmp_' + output_json_filename) if osp.exists(tmp_json_filepath): # TODO: move resume to output handler tmp_result_dict = mmengine.load(tmp_json_filepath) output_handler.results_dict = tmp_result_dict index = len(tmp_result_dict) # 4. Wrap prompts with Dataloader dataloader = self.get_dataloader(prompt_list[index:], self.batch_size) # 5. Inference for prompts in each batch logger.info('Starting inference process...') for datum in tqdm(dataloader, disable=not self.is_main_process): if ds_reader.output_column: entry, golds = list(zip(*datum)) else: entry = datum golds = [None for _ in range(len(entry))] # TODO: add more types of CoT method # 5-1. Inference sc_size times with local model with torch.no_grad(): parsed_entries = self.model.parse_template(entry, mode='gen') sc_results = [] for _ in range(self.sc_size): results = self.model.generate_from_template( entry, max_out_len=self.max_out_len, **self.generation_kwargs) sc_results.append(results) sc_prediction = list(map(list, zip(*sc_results))) generated = sc_prediction # 5-3. Save current output for prompt, prediction, gold in zip(parsed_entries, generated, golds): output_handler.save_results(prompt, prediction, index, gold=gold) index = index + 1 # 5-4. Save intermediate results if (self.save_every is not None and index % self.save_every == 0 and self.is_main_process): output_handler.write_to_json(output_json_filepath, 'tmp_' + output_json_filename) # 6. Output if self.is_main_process: os.makedirs(output_json_filepath, exist_ok=True) output_handler.write_to_json(output_json_filepath, output_json_filename) if osp.exists(tmp_json_filepath): os.remove(tmp_json_filepath) return [ sample['prediction'] for sample in output_handler.results_dict.values() ] def get_generation_prompt_list_from_retriever_indices( self, ice_idx_list: List[List[int]], retriever: BaseRetriever, gen_field_replace_token: str, max_seq_len: Optional[int] = None, ice_template: Optional[PromptTemplate] = None, prompt_template: Optional[PromptTemplate] = None): prompt_list = [] for idx, ice_idx in enumerate(ice_idx_list): ice = retriever.generate_ice(ice_idx, ice_template=ice_template) prompt = retriever.generate_prompt_for_generate_task( idx, ice, gen_field_replace_token=gen_field_replace_token, ice_template=ice_template, prompt_template=prompt_template) if max_seq_len is not None: prompt_token_num = self.model.get_token_len_from_template( prompt, mode='gen') while len(ice_idx) > 0 and prompt_token_num > max_seq_len: ice_idx = ice_idx[:-1] ice = retriever.generate_ice(ice_idx, ice_template=ice_template) prompt = retriever.generate_prompt_for_generate_task( idx, ice, gen_field_replace_token=gen_field_replace_token, ice_template=ice_template, prompt_template=prompt_template) prompt_token_num = self.model.get_token_len_from_template( prompt, mode='gen') prompt_list.append(prompt) return prompt_list