Lagent / lagent /llms /base_llm.py
Superkingjcj's picture
Upload 111 files
e679d69 verified
raw
history blame
10.4 kB
from copy import copy
from typing import Dict, List, Optional, Tuple, Union
class LMTemplateParser:
    """Intermediate prompt template parser, specifically for language models.

    Args:
        meta_template (list of dict, optional): The meta template for the
            model. Each entry configures one role and must carry a unique
            ``'role'`` key; it may also define ``'begin'`` (str, or dict with
            ``'with_name'`` / ``'without_name'`` / ``'name'``), ``'end'``,
            ``'generate'`` and ``'fallback_role'``.
    """

    def __init__(self, meta_template: Optional[List[Dict]] = None):
        self.meta_template = meta_template
        # Always define ``roles`` so the attribute exists even without a
        # meta template.
        self.roles: Dict[str, dict] = dict()  # maps role name to config
        if meta_template:
            assert isinstance(meta_template, list)
            for item in meta_template:
                assert isinstance(item, dict)
                assert item['role'] not in self.roles, \
                    'role in meta prompt must be unique!'
                self.roles[item['role']] = item.copy()

    def __call__(self, dialog) -> str:
        """Parse a prompt template, and wrap it with meta template if
        applicable.

        Args:
            dialog (str or List[str or dict]): A prompt template
                (potentially before being wrapped by meta template). Plain
                strings are passed through; dicts are messages with at least
                ``'role'`` and usually ``'content'``.

        Returns:
            str: The final string.
        """
        assert isinstance(dialog, (str, list))
        if isinstance(dialog, str):
            return dialog
        if self.meta_template:
            last_index = len(dialog) - 1
            # ``_prompt2str`` returns plain strings unchanged, so every item
            # can go through it; ``last`` marks the final message.
            return ''.join(
                self._prompt2str(item, index == last_index)
                for index, item in enumerate(dialog))
        # In case the model does not have any meta template: concatenate the
        # non-empty pieces, separated by newlines after the first item.
        prompt = ''
        last_sep = ''
        for item in dialog:
            if isinstance(item, str):
                if item:
                    prompt += last_sep + item
            elif item.get('content', ''):
                # Append the message content itself (previously the missing
                # ``'prompt'`` key was read, dropping the text).
                prompt += last_sep + item['content']
            last_sep = '\n'
        return prompt

    def _format_begin(self, role_cfg: Dict, message: Dict) -> str:
        """Render the ``begin`` marker of a role for one message.

        Args:
            role_cfg (dict): The role entry from the meta template.
            message (dict): The message; an optional ``'name'`` selects the
                ``'with_name'`` template when ``begin`` is a dict.

        Returns:
            str: The rendered begin marker ('' when the role defines none).
        """
        begin_cfg = role_cfg.get('begin') or ''
        if isinstance(begin_cfg, str):
            # A plain string marker is used verbatim, with or without a name.
            return begin_cfg
        name = message.get('name', None)
        if name is None:
            return begin_cfg.get('without_name', '')
        # Names may be remapped through the optional ``'name'`` alias table.
        display_name = begin_cfg.get('name', {}).get(name, name)
        return begin_cfg.get('with_name', '').format(name=display_name)

    def _lookup_role(self, role: str) -> dict:
        """Return the config of ``role`` or raise a descriptive KeyError."""
        try:
            return self.roles[role]
        except KeyError:
            raise KeyError(
                f'role {role!r} is not defined in the meta template') from None

    def _prompt2str(self,
                    prompt: Union[str, Dict],
                    last: bool = False) -> str:
        """Render a single dialog item with its role's begin/end markers.

        Args:
            prompt (str or dict): A raw string (returned unchanged) or a
                message dict with ``'role'`` and optional ``'content'``.
            last (bool): Whether this is the final item of the dialog.

        Returns:
            str: The rendered text.

        Raises:
            KeyError: If the message role (or its ``fallback_role``) is not
                defined in the meta template.
        """
        if isinstance(prompt, str):
            return prompt
        role_cfg = self._lookup_role(prompt['role'])
        if role_cfg.get('fallback_role'):
            role_cfg = self._lookup_role(role_cfg['fallback_role'])
        res = self._format_begin(role_cfg, prompt) + prompt.get('content', '')
        if last and role_cfg.get('generate', False):
            # Final generating turn is left open (no end marker).
            return res
        res += role_cfg.get('end', '')
        if last and role_cfg['role'] != 'assistant':
            # Open an assistant turn so the model continues from here.
            res += self._format_begin(self.roles['assistant'], {})
        return res
class BaseLLM:
    """Base class for model wrapper.

    Args:
        path (str): The path to the model.
        tokenizer_only (bool): If True, only the tokenizer will be initialized.
            Defaults to False.
        template_parser (type): Class instantiated with ``meta_template`` to
            render dialogs into prompt strings. Defaults to
            :class:`LMTemplateParser`.
        meta_template (list of dict, optional): The model's meta prompt
            template if needed, in case the requirement of injecting or
            wrapping of any meta instructions.
        max_new_tokens (int): Maximum length of output expected to be
            generated by the model. Defaults to 512.
        top_p (float): Sampling parameter stored in ``gen_params``.
            Defaults to 0.8.
        top_k (int): Sampling parameter stored in ``gen_params``.
            Defaults to 40.
        temperature (float): Sampling parameter stored in ``gen_params``.
            Defaults to 0.8.
        repetition_penalty (float): Stored in ``gen_params``. Defaults to 1.0.
        stop_words (str or list of str, optional): Stop words; a single
            string is wrapped into a one-element list. Defaults to None.
    """

    def __init__(self,
                 path: str,
                 tokenizer_only: bool = False,
                 template_parser: type = LMTemplateParser,
                 meta_template: Optional[List[Dict]] = None,
                 *,
                 max_new_tokens: int = 512,
                 top_p: float = 0.8,
                 top_k: int = 40,
                 temperature: float = 0.8,
                 repetition_penalty: float = 1.0,
                 stop_words: Optional[Union[List[str], str]] = None):
        self.path = path
        self.tokenizer_only = tokenizer_only
        # meta template
        self.template_parser = template_parser(meta_template)
        self.eos_token_id = None
        # NOTE: for the documented list-of-role-dicts template this membership
        # test never matches; it only applies to mapping-style templates used
        # with a custom ``template_parser``.
        if meta_template and 'eos_token_id' in meta_template:
            self.eos_token_id = meta_template['eos_token_id']

        if isinstance(stop_words, str):
            stop_words = [stop_words]
        # Default generation parameters; per-call overrides are merged in by
        # ``update_gen_params``.
        self.gen_params = dict(
            max_new_tokens=max_new_tokens,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            stop_words=stop_words)

    def generate(self, inputs: Union[str, List[str]], **gen_params) -> str:
        """Generate results given a str (or list of) inputs.

        Must be implemented by subclasses.

        Args:
            inputs (Union[str, List[str]]): A prompt or a batch of prompts.
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        eg.
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results as streaming given a str inputs.

        Must be implemented by subclasses.

        Args:
            inputs (str): The prompt.
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.
        """
        raise NotImplementedError

    def chat(self,
             inputs: Union[List[dict], List[List[dict]]],
             session_ids: Optional[Union[int, List[int]]] = None,
             **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): One dialog (a list
                of message dicts) or a batch of dialogs. An empty list is
                treated as a single empty dialog.
            session_ids (int or list of int, optional): Accepted for
                interface parity; not forwarded to :meth:`generate` here.
            gen_params (dict): The input params for generation.

        Returns:
            Whatever :meth:`generate` returns for the rendered prompt(s).
        """
        # Guard the ``inputs[0]`` probe so an empty dialog does not raise
        # IndexError.
        if inputs and isinstance(inputs[0], list):
            _inputs = [self.template_parser(msg) for msg in inputs]
        else:
            _inputs = self.template_parser(inputs)
        return self.generate(_inputs, **gen_params)

    def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results as streaming given a list of templates.

        Must be implemented by subclasses.

        Args:
            inputs (List[dict]): One dialog as a list of message dicts.
            gen_params (dict): The input params for generation.
        """
        raise NotImplementedError

    def tokenize(self, prompts: Union[str, List[str], List[dict],
                                      List[List[dict]]]):
        """Tokenize the input prompts.

        Must be implemented by subclasses.

        Args:
            prompts(str | List[str]): user's prompt, or a batch prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length
        """
        raise NotImplementedError

    def update_gen_params(self, **kwargs):
        """Return a shallow copy of ``self.gen_params`` updated with kwargs.

        ``self.gen_params`` itself is left unmodified.
        """
        gen_params = copy(self.gen_params)
        gen_params.update(kwargs)
        return gen_params
class AsyncLLMMixin:
    """Coroutine versions of the model-wrapper API.

    Intended to be mixed in before :class:`BaseLLM` (see ``AsyncBaseLLM``);
    ``chat`` relies on ``self.template_parser`` being provided by the class
    it is combined with.
    """

    async def generate(self,
                       inputs: Union[str, List[str]],
                       session_ids: Optional[Union[int, List[int]]] = None,
                       **gen_params) -> str:
        """Generate results given a str (or list of) inputs.

        Must be implemented by subclasses.

        Args:
            inputs (Union[str, List[str]]): A prompt or a batch of prompts.
            session_ids (int or list of int, optional): Session identifiers
                passed through by :meth:`chat`.
            gen_params (dict): The input params for generation.

        Returns:
            Union[str, List[str]]: A (list of) generated strings.

        eg.
            batched = True
            if isinstance(inputs, str):
                inputs = [inputs]
                batched = False
            response = ['']
            if batched:
                return response
            return response[0]
        """
        raise NotImplementedError

    async def stream_generate(self, inputs: str, **gen_params) -> List[str]:
        """Generate results as streaming given a str inputs.

        Must be implemented by subclasses.

        Args:
            inputs (str): The prompt.
            gen_params (dict): The input params for generation.

        Returns:
            str: A generated string.
        """
        raise NotImplementedError

    async def chat(self,
                   inputs: Union[List[dict], List[List[dict]]],
                   session_ids: Optional[Union[int, List[int]]] = None,
                   **gen_params):
        """Generate completion from a list of templates.

        Args:
            inputs (Union[List[dict], List[List[dict]]]): One dialog (a list
                of message dicts) or a batch of dialogs. An empty list is
                treated as a single empty dialog.
            session_ids (int or list of int, optional): Forwarded positionally
                to :meth:`generate`.
            gen_params (dict): The input params for generation.

        Returns:
            Whatever :meth:`generate` returns for the rendered prompt(s).
        """
        # Guard the ``inputs[0]`` probe so an empty dialog does not raise
        # IndexError.
        if inputs and isinstance(inputs[0], list):
            _inputs = [self.template_parser(msg) for msg in inputs]
        else:
            _inputs = self.template_parser(inputs)
        return await self.generate(_inputs, session_ids, **gen_params)

    async def stream_chat(self, inputs: List[dict], **gen_params):
        """Generate results as streaming given a list of templates.

        Must be implemented by subclasses.

        Args:
            inputs (List[dict]): One dialog as a list of message dicts.
            gen_params (dict): The input params for generation.
        """
        raise NotImplementedError

    async def tokenize(self, prompts: Union[str, List[str], List[dict],
                                            List[List[dict]]]):
        """Tokenize the input prompts.

        Must be implemented by subclasses.

        Args:
            prompts(str | List[str]): user's prompt, or a batch prompts

        Returns:
            Tuple(numpy.ndarray, numpy.ndarray, numpy.ndarray): prompt's token
            ids, ids' length and requested output length
        """
        raise NotImplementedError
class AsyncBaseLLM(AsyncLLMMixin, BaseLLM):
    """Asynchronous model wrapper base class.

    Construction and generation defaults come from :class:`BaseLLM`.
    :class:`AsyncLLMMixin` precedes it in the MRO, so its coroutine
    ``generate`` / ``chat`` / ``stream_*`` / ``tokenize`` methods shadow the
    synchronous ones.
    """