TwT-6's picture
Upload 2667 files
256a159 verified
raw
history blame
4.46 kB
from typing import Dict, List, Optional
from opencompass.models.base import BaseModel
from opencompass.utils import get_logger
# vllm is an optional dependency: fall back to ``None`` placeholders so this
# module can still be imported without it. ``VLLM.__init__`` asserts that
# ``LLM`` is available before any model is built.
try:
    from vllm import LLM, SamplingParams
except ImportError:
    LLM = SamplingParams = None

# Base keyword arguments for every ``vllm.LLM`` construction; user-supplied
# ``model_kwargs`` are layered on top of a copy of this dict.
DEFAULT_MODEL_KWARGS = {'trust_remote_code': True}
class VLLM(BaseModel):
    """Model Wrapper for VLLM.

    Runs text generation through a ``vllm.LLM`` engine.

    Args:
        path (str): Model name or path, passed as the first positional
            argument to ``vllm.LLM``.
        max_seq_len (int): Maximum sequence length. In ``'mid'`` mode,
            ``max_seq_len - max_out_len`` is the token budget a prompt must
            fit in before it gets truncated. Defaults to 2048.
        model_kwargs (dict, optional): Extra keyword arguments for
            ``vllm.LLM``, merged over ``DEFAULT_MODEL_KWARGS``.
        generation_kwargs (dict, optional): Default keyword arguments for
            ``vllm.SamplingParams``. A ``'do_sample'`` key is discarded.
            The dict is copied, so the caller's object is never mutated.
        meta_template (dict, optional): Forwarded to ``BaseModel``.
        mode (str): ``'none'`` sends prompts unchanged; ``'mid'`` shortens
            over-long prompts by keeping their head and tail tokens.
            Defaults to ``'none'``.
        use_fastchat_template (bool): Whether ``prompts_preproccess`` wraps
            prompts in the FastChat ``'vicuna'`` conversation template.
        end_str (str, optional): If set, each generated text is cut at the
            first occurrence of this string.
    """

    def __init__(
        self,
        path: str,
        max_seq_len: int = 2048,
        model_kwargs: Optional[dict] = None,
        generation_kwargs: Optional[dict] = None,
        meta_template: Optional[Dict] = None,
        mode: str = 'none',
        use_fastchat_template: bool = False,
        end_str: Optional[str] = None,
    ):
        super().__init__(path=path,
                         max_seq_len=max_seq_len,
                         meta_template=meta_template)
        assert LLM, ('Please install VLLM with `pip install vllm`. '
                     'note: torch==2.1.2 is required.')
        # Validate cheap arguments before paying for the model load.
        assert mode in ['none', 'mid']
        self.logger = get_logger()
        self._load_model(path, model_kwargs)
        self.tokenizer = self.model.get_tokenizer()
        # Copy: the old ``dict()`` default was shared between instances and
        # the ``pop`` below used to mutate the caller's config dict.
        self.generation_kwargs = dict(generation_kwargs or {})
        # ``do_sample`` is never forwarded to SamplingParams; presumably
        # because it is a HuggingFace-style flag vLLM does not accept.
        self.generation_kwargs.pop('do_sample', None)
        self.mode = mode
        self.use_fastchat_template = use_fastchat_template
        self.end_str = end_str

    def _load_model(self,
                    path: str,
                    add_model_kwargs: Optional[dict] = None,
                    num_retry: int = 3):
        """Build the vLLM engine and store it on ``self.model``.

        Args:
            path (str): Model name or path for ``vllm.LLM``.
            add_model_kwargs (dict, optional): Extra ``vllm.LLM`` keyword
                arguments, merged over ``DEFAULT_MODEL_KWARGS``.
            num_retry (int): Currently unused; kept for signature
                compatibility.
        """
        model_kwargs = DEFAULT_MODEL_KWARGS.copy()
        if add_model_kwargs is not None:
            model_kwargs.update(add_model_kwargs)
        self.model = LLM(path, **model_kwargs)

    def _truncate_mid(self, inputs: List[str],
                      max_out_len: int) -> List[str]:
        """Keep only the head and tail tokens of over-long prompts.

        Every prompt is tokenized and decoded again with
        ``skip_special_tokens=True``, including prompts short enough to keep.
        """
        # Clamp: a negative budget used to produce garbled head/tail slices.
        budget = max(self.max_seq_len - max_out_len, 0)
        half = budget // 2
        input_ids = self.tokenizer(inputs, truncation=False)['input_ids']
        truncated = []
        for input_id in input_ids:
            if len(input_id) > budget:
                if half == 0:
                    self.logger.warning(
                        f'max_seq_len ({self.max_seq_len}) - max_out_len '
                        f'({max_out_len}) leaves no room for the prompt; '
                        'it is truncated to an empty string.')
                # ``input_id[-half:]`` would return the WHOLE list when
                # half == 0, so slice the tail from an explicit start index.
                head = input_id[:half]
                tail = input_id[len(input_id) - half:]
                truncated.append(
                    self.tokenizer.decode(head, skip_special_tokens=True) +
                    self.tokenizer.decode(tail, skip_special_tokens=True))
            else:
                truncated.append(
                    self.tokenizer.decode(input_id, skip_special_tokens=True))
        return truncated

    def generate(self, inputs: List[str], max_out_len: int,
                 **kwargs) -> List[str]:
        """Generate results given a list of inputs.

        Args:
            inputs (List[str]): A list of strings.
            max_out_len (int): The maximum length of the output; forwarded
                to SamplingParams as ``max_tokens``.
            **kwargs: Extra SamplingParams arguments. Keys also present in
                ``self.generation_kwargs`` are overridden by the instance
                values.

        Returns:
            List[str]: A list of generated strings, each cut at ``end_str``
            when that is set.
        """
        if not inputs:
            return []
        if self.mode == 'mid':
            inputs = self._truncate_mid(inputs, max_out_len)

        # Precedence (lowest to highest): call kwargs, instance
        # generation_kwargs, then max_tokens.
        generation_kwargs = {**kwargs, **self.generation_kwargs,
                             'max_tokens': max_out_len}
        sampling_params = SamplingParams(**generation_kwargs)
        outputs = self.model.generate(inputs, sampling_params)

        output_strs = []
        for output in outputs:
            # Only the first completion of each request is used.
            generated_text = output.outputs[0].text
            if self.end_str:
                generated_text = generated_text.split(self.end_str)[0]
            output_strs.append(generated_text)
        return output_strs

    def prompts_preproccess(self, inputs: List[str]):
        """Optionally wrap each prompt in the FastChat ``'vicuna'`` template.

        Args:
            inputs (List[str]): Raw prompts.

        Returns:
            List[str]: ``inputs`` unchanged when ``use_fastchat_template`` is
            False; otherwise one templated prompt per input.

        Raises:
            ModuleNotFoundError: If the template is requested but FastChat
                is not installed.
        """
        if not self.use_fastchat_template:
            return inputs
        try:
            from fastchat.model import get_conversation_template
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                'Fastchat is not implemented. You can use '
                "'pip install \"fschat[model_worker,webui]\"' "
                'to implement fastchat.') from err
        prompts = []
        for prompt in inputs:
            # A fresh conversation per prompt so turns do not accumulate.
            conv = get_conversation_template('vicuna')
            conv.append_message(conv.roles[0], prompt)
            # A None message leaves the assistant turn open for generation.
            conv.append_message(conv.roles[1], None)
            prompts.append(conv.get_prompt())
        return prompts

    def get_token_len(self, prompt: str) -> int:
        """Get lengths of the tokenized strings.

        Args:
            prompt (str): Input string.

        Returns:
            int: Length of the input tokens
        """
        # Reuse the tokenizer cached in __init__.
        return len(self.tokenizer.encode(prompt))