import asyncio
import inspect
import os
import time
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import lmdeploy
import torch
from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig, pipeline
from lmdeploy.api import autoget_backend_config
from lmdeploy.serve import async_engine
from packaging import version
from transformers import GenerationConfig

from swift.llm import InferRequest, Template, TemplateMeta, get_model_tokenizer
from swift.plugin import Metric
from swift.utils import get_logger, get_seed
from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
                        ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig)
from .infer_engine import InferEngine
from .patch import patch_auto_config, patch_auto_tokenizer
from .utils import InferStreamer, patch_lmdeploy

try:
    # older lmdeploy releases expose EngineGenerationConfig; it was later renamed to GenerationConfig
    from lmdeploy import EngineGenerationConfig as LmdeployGenerationConfig
except ImportError:
    from lmdeploy import GenerationConfig as LmdeployGenerationConfig

logger = get_logger()


class LmdeployEngine(InferEngine):

    def __init__(
        self,
        model_id_or_path: str,
        torch_dtype: Optional[torch.dtype] = None,
        *,
        model_type: Optional[str] = None,
        use_hf: Optional[bool] = None,
        hub_token: Optional[str] = None,
        revision: Optional[str] = None,
        # engine_kwargs
        tp: int = 1,
        session_len: Optional[int] = None,
        cache_max_entry_count: float = 0.8,
        quant_policy: int = 0,
        vision_batch_size: int = 1,
        devices: Optional[List[int]] = None,
        reload_weights: bool = False,
        engine_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        version_7 = version.parse(lmdeploy.__version__) >= version.parse('0.7.0')
        if reload_weights:
            assert version_7, 'GRPO / reload_weights requires lmdeploy>=0.7.0'
        if version_7 and tp == 1:
            patch_lmdeploy(reload_weights)
        self.processor = get_model_tokenizer(
            model_id_or_path,
            torch_dtype,
            load_model=False,
            download_model=True,
            model_type=model_type,
            use_hf=use_hf,
            hub_token=hub_token,
            revision=revision)[1]
        self._post_init()

        if self.max_model_len is not None:
            self.max_model_len -= 1
        self._prepare_engine_kwargs(
            tp=tp,
            session_len=session_len,
            cache_max_entry_count=cache_max_entry_count,
            quant_policy=quant_policy,
            vision_batch_size=vision_batch_size,
            devices=devices,
            engine_kwargs=engine_kwargs)

        self.config.torch_dtype = torch_dtype or self.model_info.torch_dtype

        @contextmanager
        def disable_deepspeed():
            # temporarily report DeepSpeed ZeRO-3 as disabled while the lmdeploy engine is built
            from transformers import modeling_utils
            modeling_utils.is_deepspeed_zero3_enabled_origin = modeling_utils.is_deepspeed_zero3_enabled
            modeling_utils.is_deepspeed_zero3_enabled = lambda: False
            try:
                yield
            finally:
                modeling_utils.is_deepspeed_zero3_enabled = modeling_utils.is_deepspeed_zero3_enabled_origin
                del modeling_utils.is_deepspeed_zero3_enabled_origin

        with disable_deepspeed():
            self._prepare_engine()
        self._load_generation_config()

    def _prepare_engine_kwargs(self,
                               tp: int = 1,
                               session_len: Optional[int] = None,
                               cache_max_entry_count: float = 0.8,
                               quant_policy: int = 0,
                               vision_batch_size: int = 1,
                               devices: Optional[List[int]] = None,
                               engine_kwargs: Optional[Dict[str, Any]] = None):
        if engine_kwargs is None:
            engine_kwargs = {}
        engine_kwargs['tp'] = tp
        engine_kwargs['session_len'] = session_len
        engine_kwargs['cache_max_entry_count'] = cache_max_entry_count
        engine_kwargs['quant_policy'] = quant_policy
        backend_config = TurbomindEngineConfig(**engine_kwargs)
        backend_config = autoget_backend_config(self.model_dir, backend_config)
        if hasattr(backend_config, 'devices'):
            if devices is None:
                devices = [0]
            backend_config.devices = devices
        self.backend_config = backend_config
        logger.info(f'backend_config: {backend_config}')

        pipeline_kwargs = {}
        is_multimodal = self.model_meta.is_multimodal
        if is_multimodal:
            vision_config = VisionConfig(max_batch_size=vision_batch_size)
            pipeline_kwargs['vision_config'] = vision_config
            logger.info(f'vision_config: {vision_config}')
        self.pipeline_kwargs = pipeline_kwargs

    @contextmanager
    def _patch_pipeline(self):
        # lmdeploy matches a chat template by model name; return swift's model_type directly instead
        _old_best_match_model = async_engine.best_match_model

        def _best_match_model(*args, **kwargs) -> Optional[str]:
            return self.model_info.model_type

        async_engine.best_match_model = _best_match_model
        try:
            yield
        finally:
            async_engine.best_match_model = _old_best_match_model

    def _prepare_engine(self):
        with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config), self._patch_pipeline():
            engine = pipeline(self.model_dir, backend_config=self.backend_config, **self.pipeline_kwargs)
        self.engine = engine

    def _load_generation_config(self):
        generation_config_path = os.path.join(self.model_dir, 'generation_config.json')
        if os.path.isfile(generation_config_path):
            generation_config = GenerationConfig.from_pretrained(self.model_dir)
            kwargs = generation_config.to_dict()
            max_new_tokens = kwargs.get('max_new_tokens')
            if max_new_tokens is None:
                kwargs.pop('max_new_tokens', None)
            # drop keys that LmdeployGenerationConfig does not accept, as well as None values
            parameters = inspect.signature(LmdeployGenerationConfig).parameters
            for k, v in kwargs.copy().items():
                if k not in parameters or v is None:
                    kwargs.pop(k)
            self.generation_config = LmdeployGenerationConfig(**kwargs)
        else:
            self.generation_config = LmdeployGenerationConfig()

    def _get_stop_token_ids(self, stop_words: List[Union[str, List[int], None]]) -> List[int]:
        stop_token_ids: List[int] = []
        for stop_word in stop_words:
            if stop_word is None:
                continue
            if isinstance(stop_word, str):
                stop_word = self.tokenizer.encode(stop_word, add_special_tokens=False)
            if isinstance(stop_word, list):
                # only single-token stop words are supported here; skip multi-token ones
                if len(stop_word) != 1:
                    continue
                else:
                    stop_token = stop_word[0]
            elif isinstance(stop_word, int):
                stop_token = stop_word
            assert isinstance(stop_token, int)
            if stop_token not in stop_token_ids:
                stop_token_ids.append(stop_token)
        return stop_token_ids

    def _add_stop_words(self, generation_config: LmdeployGenerationConfig, request_config: RequestConfig,
                        template_meta: TemplateMeta) -> None:
        stop_words = (request_config.stop or []) + (self.generation_config.stop_words or []) + template_meta.stop_words
        generation_config.stop_words = self._get_stop_token_ids(stop_words)
        # the field name differs across lmdeploy versions; keep stop_token_ids in sync with stop_words
        generation_config.stop_token_ids = generation_config.stop_words

    def _prepare_generation_config(self, request_config: RequestConfig) -> LmdeployGenerationConfig:
        kwargs = {'max_new_tokens': request_config.max_tokens}
        for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
            new_value = getattr(request_config, key)
            if new_value is None:
                kwargs[key] = getattr(self.generation_config, key)
            else:
                kwargs[key] = new_value
        if request_config.seed is None:
            request_config.seed = get_seed()
        kwargs['random_seed'] = request_config.seed
        if request_config.temperature == 0:
            # treat temperature == 0 as greedy decoding
            kwargs['temperature'] = 1
            kwargs['top_k'] = 1

        if request_config.logprobs:
            kwargs['logprobs'] = 1
            if request_config.top_logprobs is not None:
                kwargs['logprobs'] = max(1, request_config.top_logprobs)

        res = LmdeployGenerationConfig(**kwargs)
        res.top_logprobs = request_config.top_logprobs
        return res

    async def _infer_stream_async(
            self, template: Template, inputs: Dict[str, Any],
            generation_config: LmdeployGenerationConfig) -> AsyncIterator[ChatCompletionStreamResponse]:
        session_id = time.time_ns()
        kwargs = {'stream_output': True, 'gen_config': generation_config, 'sequence_start': True, 'sequence_end': True}
        if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'):
            async with self.engine.model_inst(session_id) as inst:
                context = self.engine.safe_run(inst, session_id, **inputs, **kwargs)
        else:
            context = self.engine.safe_run(session_id)

        infer_streamer = InferStreamer(template)
        token_idx = 0
        async with context as gen:
            if version.parse(lmdeploy.__version__) < version.parse('0.6.5'):
                generator = await self.engine.get_generator(False, session_id)
                gen = generator.async_stream_infer(session_id=session_id, **inputs, **kwargs)
            is_finished = False
            while not is_finished:
                try:
                    output = await gen.__anext__()
                except StopAsyncIteration:
                    is_finished = True
                delta_text = infer_streamer.get_printable_text(output.token_ids, is_finished)
                if not delta_text and not is_finished:
                    continue

                logprobs = self._get_logprobs(output.logprobs, output.token_ids[token_idx:],
                                              generation_config.top_logprobs)
                token_idx = len(output.token_ids)

                usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
                toolcall = None
                if is_finished:
                    toolcall = self._get_toolcall(template.decode(output.token_ids), template)
                finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
                                                        output.status.name == 'FINISH')
                choices = [
                    ChatCompletionResponseStreamChoice(
                        index=0,
                        delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall),
                        finish_reason=finish_reason,
                        logprobs=logprobs)
                ]
                yield ChatCompletionStreamResponse(model=self.model_name, choices=choices, usage=usage_info)

    async def _infer_full_async(self, template: Template, inputs: Dict[str, Any],
                                generation_config: LmdeployGenerationConfig) -> ChatCompletionResponse:
        session_id = time.time_ns()
        kwargs = {'stream_output': False, 'gen_config': generation_config, 'sequence_start': True, 'sequence_end': True}
        if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'):
            async with self.engine.model_inst(session_id) as inst:
                async with self.engine.safe_run(inst, session_id, **inputs, **kwargs) as gen:
                    async for output in gen:
                        pass
                if self.engine.backend == 'pytorch':
                    # the pytorch backend does not end the session by itself
                    await inst.async_end(session_id)
        else:
            async with self.engine.safe_run(session_id):
                generator = await self.engine.get_generator(False, session_id)
                async for output in generator.async_stream_infer(session_id=session_id, **inputs, **kwargs):
                    pass

        response = template.decode(output.token_ids)
        logprobs = self._get_logprobs(output.logprobs, output.token_ids, generation_config.top_logprobs)

        usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
        toolcall = self._get_toolcall(response, template)
        finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
                                                output.status.name == 'FINISH')
        choices = [
            ChatCompletionResponseChoice(
                index=0,
                message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
                finish_reason=finish_reason,
                logprobs=logprobs)
        ]
        return ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info)

    async def infer_async(self,
                          infer_request: InferRequest,
                          request_config: Optional[RequestConfig] = None,
                          *,
                          template: Optional[Template] = None,
                          pre_infer_hook=None,
                          **kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
        request_config = deepcopy(request_config or RequestConfig())
        if template is None:
            template = self.default_template

        template.set_mode('lmdeploy')

        loop = asyncio.get_running_loop()
        with torch.inference_mode():
            inputs = await loop.run_in_executor(None, template.encode, infer_request)
        images = inputs.pop('images', None)
        if images:
            # multimodal inputs: run lmdeploy's vision encoder and attach the image features
            if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'):
                messages = self.engine._convert_prompts(('', images))
                messages = await self.engine.async_convert_to_pil_images(messages)
                results = await self.engine.vl_encoder.preprocess(messages)
                if self.engine.backend == 'turbomind':
                    results = await self.engine.vl_encoder.async_infer(results)
                    inputs['images'] = [result['content'] for result in results if result['role'] == 'forward'][0]
                    await template.prepare_lmdeploy_turbomind_inputs(inputs)
                else:
                    inputs['images'] = results[1]['content']
                    await template.prepare_lmdeploy_pytorch_inputs(inputs)
            else:
                inputs['images'] = await self.engine.vl_encoder.async_infer(images)
                await template.prepare_lmdeploy_turbomind_inputs(inputs)

        self.set_default_max_tokens(request_config, inputs)
        generation_config = self._prepare_generation_config(request_config)
        self._add_stop_words(generation_config, request_config, template.template_meta)
        kwargs.update({'template': template, 'inputs': inputs, 'generation_config': generation_config})
        if pre_infer_hook:
            kwargs = pre_infer_hook(kwargs)
        if request_config.stream:
            return self._infer_stream_async(**kwargs)
        else:
            return await self._infer_full_async(**kwargs)

    def _batch_infer_stream(self, *args, **kwargs):
        # clear lmdeploy state that is bound to a previous asyncio event loop before batched streaming
        if hasattr(self.engine, 'vl_encoder'):
            self.engine.vl_encoder._loop_task = None
        if hasattr(self.engine, 'free_insts'):
            self.engine.free_insts = None
        return super()._batch_infer_stream(*args, **kwargs)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
        metrics: Optional[List[Metric]] = None,
        *,
        template: Optional[Template] = None,
        use_tqdm: Optional[bool] = None,
    ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
        return super().infer(infer_requests, request_config, metrics, template=template, use_tqdm=use_tqdm)
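

# Minimal usage sketch (illustrative only; the model id and request values below are placeholders,
# not something this module defines):
#
#     from swift.llm import InferRequest, RequestConfig
#
#     engine = LmdeployEngine('Qwen/Qwen2-7B-Instruct', tp=1)
#     request = InferRequest(messages=[{'role': 'user', 'content': 'Hello!'}])
#     [response] = engine.infer([request], RequestConfig(max_tokens=64))
#     print(response.choices[0].message.content)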