| import re |
| import sys |
| import threading |
| import time |
| import warnings |
| from abc import abstractmethod |
| from copy import deepcopy |
| from queue import Queue |
| from time import sleep |
| from typing import Dict, List, Optional, Tuple, Union |
|
|
| from opencompass.utils import get_logger |
| from opencompass.utils.prompt import PromptList |
|
|
| from .base import BaseModel |
|
|
| PromptType = Union[PromptList, str] |
|
|
|
|
class BaseAPIModel(BaseModel):
    """Base class for API model wrapper.

    Args:
        path (str): The path to the model.
        query_per_second (int): The maximum queries allowed per second
            between two consecutive calls of the API. Defaults to 1.
        rpm_verbose (bool): Forwarded to the rate limiter; when True it logs
            the number of requests issued during the last 60 seconds.
            Defaults to False.
        retry (int): Number of retries if the API call fails. Defaults to 2.
        max_seq_len (int): The maximum sequence length of the model. Defaults
            to 2048.
        meta_template (Dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        generation_kwargs (Dict, optional): The generation kwargs for the
            model. Defaults to None, which is stored as a new empty dict.
    """

    is_api: bool = True

    def __init__(self,
                 path: str,
                 query_per_second: int = 1,
                 rpm_verbose: bool = False,
                 retry: int = 2,
                 max_seq_len: int = 2048,
                 meta_template: Optional[Dict] = None,
                 generation_kwargs: Optional[Dict] = None):
        # NOTE(review): BaseModel.__init__ is not invoked here; every
        # attribute is set directly. Confirm this is intentional.
        self.path = path
        self.max_seq_len = max_seq_len
        self.meta_template = meta_template
        self.retry = retry
        self.query_per_second = query_per_second
        self.token_bucket = TokenBucket(query_per_second, rpm_verbose)
        self.template_parser = APITemplateParser(meta_template)
        self.logger = get_logger()
        # A fresh dict per instance: a `dict()` default in the signature would
        # be one object shared (and mutated) across all instances.
        self.generation_kwargs = ({} if generation_kwargs is None else
                                  generation_kwargs)

    @abstractmethod
    def generate(self, inputs: List[PromptType],
                 max_out_len: int) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings or PromptDicts.
                The PromptDict should be organized in OpenCompass'
                API format.
            max_out_len (int): The maximum length of the output.

        Returns:
            List[str]: A list of generated strings.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(f'{self.__class__.__name__} does not support'
                                  ' gen-based evaluation yet, try ppl-based '
                                  'instead.')

    def flush(self):
        """Ensure simultaneous emptying of stdout and stderr when concurrent
        resources are available.

        When employing multiprocessing with standard I/O redirected to files,
        it is crucial to clear internal data for examination or prevent log
        loss in case of system failures.

        Nothing is flushed unless the instance has a ``tokens`` attribute.
        """
        if hasattr(self, 'tokens'):
            sys.stdout.flush()
            sys.stderr.flush()

    def acquire(self):
        """Acquire concurrent resources if exists.

        This behavior will fall back to wait with query_per_second if there are
        no concurrent resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.acquire()
        else:
            self.wait()

    def release(self):
        """Release concurrent resources if acquired.

        This behavior will fall back to do nothing if there are no concurrent
        resources.
        """
        if hasattr(self, 'tokens'):
            self.tokens.release()

    @abstractmethod
    def get_ppl(self,
                inputs: List[PromptType],
                mask_length: Optional[List[int]] = None) -> List[float]:
        """Get perplexity scores given a list of inputs.

        Args:
            inputs (List[str or PromptList]): A list of strings.
            mask_length (Optional[List[int]]): A list of mask lengths. If
                provided, the perplexity scores will be calculated with the
                first mask_length[i] tokens masked out. It's okay to skip
                its implementation if advanced features in PPLInferencer is
                not needed.

        Returns:
            List[float]: A list of perplexity scores.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(f'{self.__class__.__name__} does not support'
                                  ' ppl-based evaluation yet, try gen-based '
                                  'instead.')

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized string. Only English and Chinese
        characters are counted for now. Users are encouraged to override this
        method if more accurate length is needed.

        Every maximal run of ASCII letters/digits counts as one token and
        every character in the range U+4E00-U+9FFF counts as one token.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        # A run matched by this pattern never contains whitespace, so each
        # match is exactly one "word".
        english_count = len(re.findall(r'[A-Za-z0-9]+', prompt))
        # Chinese text is counted per character rather than per run.
        chinese_count = sum(
            len(part) for part in re.findall(r'[\u4e00-\u9FFF]+', prompt))
        return english_count + chinese_count

    def wait(self):
        """Wait till the next query can be sent.

        Applicable in both single-thread and multi-thread environments.
        """
        return self.token_bucket.get_token()

    def to(self, device):
        """Accept a device argument and do nothing."""
        pass
|
|
|
|
class APITemplateParser:
    """Intermediate prompt template parser, specifically for API models.

    Args:
        meta_template (Dict, optional): The meta template for the model. When
            given it must contain a list under ``'round'`` and may contain a
            list under ``'reserved_roles'``.
    """

    def __init__(self, meta_template: Optional[Dict] = None):
        self.meta_template = meta_template
        # Maps role name to its config. Always defined so that helpers never
        # hit an AttributeError when no meta template was supplied.
        self.roles: Dict[str, dict] = dict()

        # Check meta template
        if meta_template:
            assert 'round' in meta_template, 'round is required in meta' \
                ' template'
            assert isinstance(meta_template['round'], list)
            keys_to_check = ['round']

            if 'reserved_roles' in meta_template:
                assert isinstance(meta_template['reserved_roles'], list)
                keys_to_check.append('reserved_roles')

            for meta_key in keys_to_check:
                for item in meta_template[meta_key]:
                    assert isinstance(item, (str, dict))
                    if isinstance(item, dict):
                        assert item['role'] not in self.roles, \
                            'role in meta prompt must be unique!'
                        self.roles[item['role']] = item.copy()

    def parse_template(self, prompt_template: 'PromptType',
                       mode: str) -> 'PromptType':
        """Parse the intermediate prompt template, and wrap it with meta
        template if applicable. When the meta template is set and the input is
        a PromptList, the return value will be a PromptList containing the full
        conversation history. Each item looks like:

        .. code-block:: python

            {'role': 'user', 'prompt': '...'}

        Without a meta template, a PromptList is flattened into a single
        newline-joined string. A plain string is returned unchanged, and a
        list/tuple of templates is parsed element by element.

        Args:
            prompt_template (List[str or PromptList]): An intermediate prompt
                template (potentially before being wrapped by meta template).
            mode (str): Parsing mode. Choices are 'ppl' and 'gen'.

        Returns:
            List[str or PromptList]: The finalized prompt or a conversation.

        Raises:
            ValueError: If a section marker has a ``pos`` other than
                ``'begin'`` or ``'end'``.
        """
        assert isinstance(prompt_template, (str, list, PromptList, tuple))

        if not isinstance(prompt_template, (str, PromptList)):
            return [self.parse_template(p, mode=mode) for p in prompt_template]

        assert mode in ['ppl', 'gen']
        if isinstance(prompt_template, str):
            return prompt_template

        if self.meta_template:
            prompt = PromptList()
            # Whether to keep generating the prompt
            generate = True
            # Stores tuples: (section_name, index of first item in section)
            section_stack = []

            for i, item in enumerate(prompt_template):
                if not generate:
                    break
                if isinstance(item, str):
                    if item.strip():
                        warnings.warn('Non-empty string in prompt template '
                                      'will be ignored in API models.')
                elif isinstance(item, dict) and 'section' in item:
                    if item['pos'] == 'end':
                        section_name, start_idx = section_stack.pop(-1)
                        assert section_name == item['section']
                        if section_name in ['round', 'ice']:
                            dialogue = prompt_template[start_idx:i]
                            round_ranges = self._split_rounds(
                                dialogue, self.meta_template['round'])
                            last_round = len(round_ranges) - 2
                            # A section may hold several rounds; each one is
                            # rendered through the round meta template.
                            for round_idx in range(len(round_ranges) - 1):
                                start = round_ranges[round_idx]
                                end = round_ranges[round_idx + 1]
                                round_template = dialogue[start:end]
                                role_dict = self._update_role_dict(
                                    round_template)
                                # Start generating only when the mode is
                                # generation and the template reaches the
                                # last round of a 'round' section.
                                api_prompts, generate = self._prompt2api(
                                    self.meta_template['round'],
                                    role_dict,
                                    for_gen=mode == 'gen'
                                    and section_name == 'round'
                                    and round_idx == last_round)
                                prompt += api_prompts
                    elif item['pos'] == 'begin':
                        assert item['section'] in [
                            'begin', 'round', 'end', 'ice'
                        ]
                        section_stack.append((item['section'], i + 1))
                    else:
                        raise ValueError(f'Invalid pos {item["pos"]}')
                elif section_stack[-1][0] in ['begin', 'end']:
                    role_dict = self._update_role_dict(item)
                    api_prompts, generate = self._prompt2api(
                        item, role_dict, for_gen=mode == 'gen')
                    # A role that triggers generation yields None; appending
                    # it would break the role-merging step below.
                    if api_prompts:
                        prompt.append(api_prompts)

            # Nothing was produced, so there is nothing to merge.
            if not prompt:
                return prompt

            # Merge the consecutive prompts assigned to the same role
            new_prompt = PromptList([prompt[0]])
            last_role = prompt[0]['role']
            for item in prompt[1:]:
                if item['role'] == last_role:
                    new_prompt[-1]['prompt'] += '\n' + item['prompt']
                else:
                    last_role = item['role']
                    new_prompt.append(item)
            prompt = new_prompt

        else:
            # In case the model does not have any meta template
            prompt = ''
            last_sep = ''
            for item in prompt_template:
                # Pure section markers carry no text.
                if isinstance(item, dict) and set(item) == {'section', 'pos'}:
                    continue
                if isinstance(item, str):
                    if item:
                        prompt += last_sep + item
                elif item.get('prompt', ''):
                    prompt += last_sep + item.get('prompt', '')
                last_sep = '\n'
        return prompt

    def _update_role_dict(self, prompts: Union[List, str,
                                               Dict]) -> Dict[str, Dict]:
        """Update the default role dict with the given prompts.

        Args:
            prompts (Union[List, str, Dict]): A single prompt dict, a list of
                prompt items, or a string (which carries no role information).

        Returns:
            Dict[str, Dict]: A deep copy of ``self.roles`` in which every
            prompt dict has been merged into the entry of its role, or of its
            ``fallback_role`` when the role itself is unknown.

        Raises:
            KeyError: If a prompt names neither a known role nor a known
                fallback role.
        """
        role_dict = deepcopy(self.roles)
        if isinstance(prompts, str):
            return role_dict
        elif isinstance(prompts, dict):
            prompts = [prompts]
        for prompt in prompts:
            if isinstance(prompt, dict):
                role = prompt['role']
                if role not in self.roles:
                    role = prompt.get('fallback_role', None)
                    if role not in self.roles:
                        # Fail with the same message `_split_rounds` uses
                        # instead of an opaque KeyError on the lookup below.
                        raise KeyError(f'{prompt} neither has an appropriate '
                                       'role nor a fallback role.')
                role_dict[role].update(prompt)
        return role_dict

    def _split_rounds(
            self, prompt_template: List[Union[str, Dict]],
            single_round_template: List[Union[str, Dict]]) -> List[int]:
        """Split the prompt template into rounds, based on single round
        template.

        Return the index ranges of each round. Specifically,
        prompt_template[res[i]:res[i+1]] represents the i-th round in the
        template.

        Raises:
            KeyError: If an item names neither a role of the round template
                nor a fallback role that is part of it.
        """
        role_idxs = {
            role_cfg['role']: i
            for i, role_cfg in enumerate(single_round_template)
            if not isinstance(role_cfg, str)
        }
        last_role_idx = -1
        cutoff_idxs = [0]
        for idx, template in enumerate(prompt_template):
            if isinstance(template, str):
                continue
            role_idx = role_idxs.get(template['role'], None)
            if role_idx is None:
                try:
                    role_idx = role_idxs[template['fallback_role']]
                except KeyError:
                    raise KeyError(f'{template} neither has an appropriate '
                                   'role nor a fallback role.') from None
            # A role that does not come later in the round template than the
            # previous one marks the start of a new round.
            if role_idx <= last_role_idx:
                cutoff_idxs.append(idx)
            last_role_idx = role_idx
        cutoff_idxs.append(len(prompt_template))
        return cutoff_idxs

    def _prompt2api(
            self,
            prompts: Union[List, str],
            role_dict: Dict[str, Dict],
            for_gen: bool = False
    ) -> Tuple[Union[str, Dict, List, None], bool]:
        """Convert the prompts to a API-style prompts, given an updated
        role_dict.

        Args:
            prompts (Union[List, str]): The prompts to be converted.
            role_dict (Dict[str, Dict]): The updated role dict.
            for_gen (bool): If True, the prompts will be converted for
                generation tasks. The conversion stops before the first
                role whose "generate" is set to True.

        Returns:
            Tuple[Union[str, Dict, List, None], bool]: The converted prompt
            and whether the follow-up conversion should be proceeded. A str
            input is returned unchanged, a dict input yields one API message
            (or None when generation starts at that role), and a list input
            yields a list of API messages.

        Raises:
            TypeError: If a list input contains a bare string.
        """
        cont = True
        if isinstance(prompts, str):
            return prompts, cont
        elif isinstance(prompts, dict):
            api_role, cont = self._role2api_role(prompts, role_dict, for_gen)
            return api_role, cont

        res = []
        for prompt in prompts:
            if isinstance(prompt, str):
                raise TypeError('Mixing str without explictt role is not '
                                'allowed in API models!')
            else:
                api_role, cont = self._role2api_role(prompt, role_dict,
                                                     for_gen)
                if api_role:
                    res.append(api_role)
                if not cont:
                    break
        return res, cont

    def _role2api_role(self,
                       role_prompt: Dict,
                       role_dict: Dict[str, Dict],
                       for_gen: bool = False) -> Tuple[Optional[Dict], bool]:
        """Convert a role prompt to an API message, given an updated
        role_dict.

        Args:
            role_prompt (Dict): The role prompt to be converted.
            role_dict (Dict[str, Dict]): The updated role dict.
            for_gen (bool): If True, the prompts will be converted for
                generation tasks. The conversion stops before the first
                role whose "generate" is set to True.

        Returns:
            Tuple[Optional[Dict], bool]: The converted message (a dict with
            'role' and 'prompt' keys, or None when generation starts at this
            role), and whether the follow-up conversion should be proceeded.

        Raises:
            KeyError: If neither the role nor its fallback role is present in
                ``role_dict``.
        """
        merged_prompt = role_dict.get(
            role_prompt['role'],
            role_dict.get(role_prompt.get('fallback_role')))
        if merged_prompt is None:
            raise KeyError(f'{role_prompt} neither has an appropriate '
                           'role nor a fallback role.')
        # In generation mode the model itself produces this role's content.
        if for_gen and merged_prompt.get('generate', False):
            return None, False
        res = {}
        res['role'] = merged_prompt['api_role']
        res['prompt'] = merged_prompt.get('begin', '')
        res['prompt'] += merged_prompt.get('prompt', '')
        res['prompt'] += merged_prompt.get('end', '')
        return res, True
|
|
|
|
class TokenBucket:
    """A token bucket for rate limiting.

    Args:
        rate (float): The rate of the token bucket, in tokens (queries) per
            second. It also caps how many tokens may accumulate.
        verbose (bool): If True, every granted token logs how many tokens
            were granted during the last 60 seconds. Defaults to False.
    """

    def __init__(self, rate, verbose=False):
        self._rate = rate
        self._tokens = threading.Semaphore(0)
        self.started = False
        # Timestamps of recently granted tokens; only used when verbose.
        self._request_queue = Queue()
        self.logger = get_logger()
        self.verbose = verbose
        # Guards the lazy start of the refill thread and the RPM bookkeeping,
        # since get_token() may be called from several threads at once.
        self._lock = threading.Lock()

    def _add_tokens(self):
        """Add tokens to the bucket.

        Runs forever; intended as the target of a daemon thread.
        """
        while True:
            # NOTE: reads the semaphore's private counter to cap the number
            # of unused tokens at the configured rate.
            if self._tokens._value < self._rate:
                self._tokens.release()
            sleep(1 / self._rate)

    def get_token(self):
        """Get a token from the bucket.

        Blocks until a token is available. The refill thread is started on
        the first call.
        """
        # Without the lock two concurrent first callers could both observe
        # `started == False` and spawn two refill threads, doubling the rate.
        with self._lock:
            if not self.started:
                self.started = True
                threading.Thread(target=self._add_tokens,
                                 daemon=True).start()
        self._tokens.acquire()
        if self.verbose:
            cur_time = time.time()
            # Pruning peeks at the head and then pops it; doing so under the
            # lock keeps another thread from emptying the queue in between.
            with self._lock:
                while not self._request_queue.empty():
                    if cur_time - self._request_queue.queue[0] > 60:
                        self._request_queue.get()
                    else:
                        break
                self._request_queue.put(cur_time)
                rpm = self._request_queue.qsize()
            self.logger.info(f'Current RPM {rpm}.')
|
|