Spaces:
Sleeping
Sleeping
from copy import copy | |
from typing import Dict, List, Optional, Tuple, Union | |
class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.

    Args:
        meta_template (list of dict, optional): The meta template for the
            model. Each item configures one role and must carry a unique
            ``'role'`` key. The other keys read by this parser are
            ``'begin'`` (str, or dict with ``'with_name'`` /
            ``'without_name'`` / ``'name'``), ``'end'``, ``'generate'``
            and ``'fallback_role'``.
    """

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        # Maps role name to a copy of its config. Always defined so the
        # attribute exists even when no meta template is given.
        self.roles: Dict[str, dict] = dict()
        if meta_template:
            assert isinstance(meta_template, list)
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog) -> str:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        Args:
            dialog (str or list of (str or dict)): A prompt template
                (potentially before being wrapped by meta template). Dict
                items are messages with ``'role'``, ``'content'`` and
                optionally ``'name'`` keys.

        Returns:
            str: The final string. A ``str`` dialog is returned unchanged.
            Without a meta template, the non-empty string items and
            message contents are joined with ``'\\n'``.
        """
        assert isinstance(dialog, (str, list))
        if isinstance(dialog, str):
            return dialog
        if self.meta_template:
            last_index = len(dialog) - 1
            return ''.join(
                item if isinstance(item, str) else self._prompt2str(
                    item, index == last_index)
                for index, item in enumerate(dialog))
        # In case the model does not have any meta template: join the
        # non-empty pieces so no stray separator is emitted for skipped
        # (empty) items.
        pieces = []
        for item in dialog:
            text = item if isinstance(item, str) else item.get('content', '')
            if text:
                pieces.append(text)
        return '\n'.join(pieces)

    def _format_begin(self, role_cfg: Dict, message: Dict) -> str:
        """Render the ``begin`` marker of a role for one message.

        Args:
            role_cfg (dict): The role's config from the meta template.
            message (dict): The message being rendered. Only its optional
                ``'name'`` key is read.

        Returns:
            str: A plain-string ``begin`` is returned verbatim. For a dict
            ``begin``, ``'without_name'`` is used when the message has no
            name. Otherwise ``'with_name'`` is formatted with the name,
            mapped through ``begin['name']`` when present. A missing or
            empty ``begin`` yields ``''``.
        """
        begin_cfg = role_cfg.get('begin') or ''
        if isinstance(begin_cfg, str):
            # A plain string marker does not depend on the message name.
            return begin_cfg
        name = message.get('name', None)
        if name is None:
            return begin_cfg.get('without_name', '')
        display_name = begin_cfg.get('name', {}).get(name, name)
        return begin_cfg.get('with_name', '').format(name=display_name)

    def _prompt2str(self,
                    prompt: Union[str, Dict],
                    last: bool = False) -> str:
        """Render one dialog item with its role's begin/end markers.

        Args:
            prompt (str or dict): A raw string (returned as-is) or a message
                dict with a ``'role'`` key.
            last (bool): Whether this is the final item of the dialog.

        Returns:
            str: The rendered item. If ``last`` and the role has
            ``generate`` set, the ``end`` marker is omitted so the model
            continues the text. If ``last`` and the role is not
            ``assistant``, the assistant's ``begin`` marker is appended,
            when an assistant role is configured.

        Raises:
            KeyError: If the message role, or its ``fallback_role``, is not
                defined in the meta template.
        """
        if isinstance(prompt, str):
            return prompt
        role_cfg = self.roles.get(prompt['role'])
        if role_cfg is None:
            raise KeyError(
                f"role {prompt['role']!r} is not defined in the meta template")
        if role_cfg.get('fallback_role'):
            role_cfg = self.roles[role_cfg['fallback_role']]
        res = self._format_begin(role_cfg, prompt)
        content = prompt.get('content', '')
        if last and role_cfg.get('generate', False):
            return res + content
        res += content + role_cfg.get('end', '')
        if last and role_cfg['role'] != 'assistant':
            assistant_cfg = self.roles.get('assistant')
            if assistant_cfg is not None:
                res += self._format_begin(assistant_cfg, {})
        return res
class BaseLLM:
    """Base class for model wrapper.

    Args:
        path (str): The path to the model.
        tokenizer_only (bool): If True, only the tokenizer will be
            initialized. Defaults to False.
        template_parser (type): Class (or factory) called with
            ``meta_template`` to build ``self.template_parser``.
            Defaults to :class:`LMTemplateParser`.
        meta_template (list of dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        max_new_tokens (int): Maximum length of output expected to be
            generated by the model. Defaults to 512.
        top_p (float): Stored in ``self.gen_params``. Defaults to 0.8.
        top_k (int): Stored in ``self.gen_params``. Defaults to 40.
        temperature (float): Stored in ``self.gen_params``.
            Defaults to 0.8.
        repetition_penalty (float): Stored in ``self.gen_params``.
            Defaults to 1.0.
        stop_words (list of str or str, optional): A single string is
            wrapped into a one-element list. Defaults to None.
    """

    def __init__(self,
                 path: str,
                 tokenizer_only: bool = False,
                 template_parser: type = LMTemplateParser,
                 meta_template: Optional[List[Dict]] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: int = 40,
                 temperature: float = 0.8,
                 repetition_penalty: float = 1.0,
                 stop_words: Optional[Union[List[str], str]] = None):
        self.path = path
        self.tokenizer_only = tokenizer_only
        # meta template
        self.template_parser = template_parser(meta_template)
        self.eos_token_id = None
        # A list-of-role-dicts template cannot carry a top-level
        # ``eos_token_id``. A membership test on the list never matches a
        # dict, and indexing a list by str raises TypeError. Only a
        # dict-style template, accepted by custom parsers, is consulted.
        if isinstance(meta_template, dict):
            self.eos_token_id = meta_template.get('eos_token_id')
        if isinstance(stop_words, str):
            stop_words = [stop_words]
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words)

    def generate(self, inputs: Union[str, List[str]], **gen_params) -> str:
        """Generate results given a str (or list of) inputs.

        Args:
            inputs (Union[str, List[str]]): A prompt or a batch of prompts.
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        Example implementation::

            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results as streaming given a str inputs.

        Args:
            inputs (str): The prompt.
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.
        """
        raise NotImplementedError

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             session_ids: Union[int, List[int]] = None,
             **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): One dialog (a list
                of messages) or a batch of dialogs.
            session_ids (int or list of int, optional): Accepted for
                interface parity with the async variant; not forwarded to
                :meth:`generate` here.
            gen_params (dict): The input params for generation.

        Returns:
            Whatever :meth:`generate` returns for the rendered prompt (or
            list of prompts, for a batch of dialogs).
        """
        # Guard the index so an empty dialog is rendered instead of raising
        # IndexError.
        if inputs and isinstance(inputs[0], list):
            _inputs = [self.template_parser(msg) for msg in inputs]
        else:
            _inputs = self.template_parser(inputs)
        return self.generate(_inputs, **gen_params)

    def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results as streaming given a list of templates.

        Args:
            inputs (List[dict]): One dialog (a list of messages).
            gen_params (dict): The input params for generation.
        """
        raise NotImplementedError

    def tokenize(self, prompts: Union[str, List[str], List[dict],
                                      List[List[dict]]]):
        """Tokenize the input prompts.

        Args:
            prompts (str | List[str] | List[dict] | List[List[dict]]):
                user's prompt, or a batch prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length
        """
        raise NotImplementedError

    def update_gen_params(self, **kwargs):
        """Return a shallow copy of ``self.gen_params`` overridden by
        ``kwargs``. ``self.gen_params`` itself is not modified."""
        gen_params = copy(self.gen_params)
        gen_params.update(kwargs)
        return gen_params
class AsyncLLMMixin:
    """Coroutine versions of the model-wrapper generation interface.

    Meant to be listed before a synchronous base class in the bases (as in
    ``AsyncBaseLLM(AsyncLLMMixin, BaseLLM)``) so these ``async`` methods
    take precedence. :meth:`chat` expects the host class to provide a
    ``template_parser`` callable.
    """

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Union[int, List[int]] = None,
                       **gen_params) -> Union[str, List[str]]:
        """Generate results given a str (or list of) inputs.

        Args:
            inputs (Union[str, List[str]]): A prompt or a batch of prompts.
            session_ids (int or list of int, optional): Session identifier(s)
                for the request(s).
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        Example implementation::

            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    async def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results as streaming given a str inputs.

        Args:
            inputs (str): The prompt.
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.
        """
        raise NotImplementedError

    async def chat(self,
                   inputs: Union[List[dict], List[List[dict]]],
                   session_ids: Union[int, List[int]] = None,
                   **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): One dialog (a list
                of messages) or a batch of dialogs.
            session_ids (int or list of int, optional): Forwarded to
                :meth:`generate`.
            gen_params (dict): The input params for generation.

        Returns:
            The awaited result of :meth:`generate` for the rendered prompt
            (or list of prompts, for a batch of dialogs).
        """
        # Guard the index so an empty dialog is rendered instead of raising
        # IndexError.
        if inputs and isinstance(inputs[0], list):
            _inputs = [self.template_parser(msg) for msg in inputs]
        else:
            _inputs = self.template_parser(inputs)
        return await self.generate(_inputs, session_ids, **gen_params)

    async def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results as streaming given a list of templates.

        Args:
            inputs (List[dict]): One dialog (a list of messages).
            gen_params (dict): The input params for generation.
        """
        raise NotImplementedError

    async def tokenize(self, prompts: Union[str, List[str], List[dict],
                                            List[List[dict]]]):
        """Tokenize the input prompts.

        Args:
            prompts (str | List[str] | List[dict] | List[List[dict]]):
                user's prompt, or a batch prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length
        """
        raise NotImplementedError
class AsyncBaseLLM(AsyncLLMMixin, BaseLLM):
    """Async model-wrapper base.

    Because ``AsyncLLMMixin`` precedes ``BaseLLM`` in the bases, its
    coroutine ``generate``/``chat``/``stream_*``/``tokenize`` methods
    override the synchronous ones. ``__init__`` and ``update_gen_params``
    come from ``BaseLLM``.
    """