import os
import re
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from functools import partial
from itertools import repeat
from queue import Queue
from typing import List, Optional, Union

import torch
import torch.distributed as dist
from packaging import version
from transformers import GenerationConfig, LogitsProcessor
from transformers.generation.streamers import BaseStreamer

from swift.llm.model.register import fix_do_sample_warning
from swift.utils import get_current_device, get_device, get_device_count, get_node_setting, set_device
from ..protocol import RequestConfig


@dataclass
class AdapterRequest:
    name: str
    path: str


class InferTools:

    @staticmethod
    def _is_chinese_char(cp: int) -> bool:
        """Checks whether CP is the codepoint of a CJK character."""
        if ((0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF) or (0x20000 <= cp <= 0x2A6DF)
                or (0x2A700 <= cp <= 0x2B73F) or (0x2B740 <= cp <= 0x2B81F) or (0x2B820 <= cp <= 0x2CEAF)
                or (0xF900 <= cp <= 0xFAFF) or (0x2F800 <= cp <= 0x2FA1F)):
            return True

        return False


class InferStreamer(InferTools):

    def __init__(self, template, **decode_kwargs):
        self.template = template
        self.tokenizer = template.tokenizer

        self.cache_idx = 0  # number of tokens already consumed from raw_tokens
        self.print_idx = 0  # offset of the already printed text within the decoded response
        self.decode_kwargs = decode_kwargs
        self.first_num_space = -1  # number of leading spaces seen in the first decoded chunk
        self.first_token = True

    def _align_blank_suffix(self, response: str) -> str:
        # Incremental decodes are not always consistent about leading spaces;
        # align each chunk to the number of leading spaces seen in the first chunk.
        cur_num_space = len(response) - len(response.lstrip(' '))
        if self.first_num_space == -1:
            self.first_num_space = cur_num_space
        elif cur_num_space < self.first_num_space:
            response = ' ' * (self.first_num_space - cur_num_space) + response
        elif cur_num_space > self.first_num_space:
            response = response[cur_num_space - self.first_num_space:]
        return response

    def _get_response(self, response: str, is_finished: bool, token_len: int) -> str:
        if self.first_token:
            # The very first call is emitted as-is.
            printable_text = response
            self.first_token = False
        elif response.endswith('\n') or is_finished:
            # After a newline (or once generation finishes), flush everything and reset the cache.
            printable_text = response[self.print_idx:]
            self.cache_idx += token_len
            self.first_num_space = -1
            self.print_idx = 0
        elif len(response) > 0 and self._is_chinese_char(ord(response[-1])):
            # A trailing CJK character is safe to print immediately.
            printable_text = response[self.print_idx:]
            self.print_idx += len(printable_text)
        else:
            # Otherwise print up to (and including) the last space, holding back partial words.
            printable_text = response[self.print_idx:response.rfind(' ') + 1]
            self.print_idx += len(printable_text)
        return printable_text

    def get_printable_text(self, raw_tokens: List[int], is_finished: bool) -> str:
        raw_tokens = raw_tokens[self.cache_idx:]
        if self.first_token:
            raw_tokens = []
        response = self.template.decode(
            raw_tokens, is_finished=is_finished, tokenizer_kwargs=self.decode_kwargs, first_token=self.first_token)
        response = self._align_blank_suffix(response)
        return self._get_response(response, is_finished, len(raw_tokens))
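

# Illustrative sketch (not part of the original module): how a generation loop might drive
# InferStreamer to emit incremental, printable deltas. `template` is assumed to be a swift
# Template instance and `token_iter` an iterable yielding the accumulated token-id list after
# each generation step.
def _example_stream_print(template, token_iter) -> str:
    streamer = InferStreamer(template)
    text = ''
    raw_tokens: List[int] = []
    for raw_tokens in token_iter:
        # Only the newly printable piece is returned; partial words are held back.
        text += streamer.get_printable_text(raw_tokens, is_finished=False)
    # Flush whatever is still buffered once generation has finished.
    text += streamer.get_printable_text(raw_tokens, is_finished=True)
    return text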


class StreamerMixin:

    def __init__(self):
        self.queue = Queue()

    def __iter__(self):
        return self

    def __next__(self) -> torch.Tensor:
        value = self.queue.get()
        if value is None:
            raise StopIteration()
        else:
            return value


class TokensIteratorStreamer(StreamerMixin, BaseStreamer):

    def put(self, value: torch.Tensor) -> None:
        self.queue.put(value)

    def end(self) -> None:
        self.queue.put(None)
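

# Illustrative sketch (not part of the original module): TokensIteratorStreamer follows the
# transformers BaseStreamer contract, so `generate()` can run in a background thread while the
# caller iterates over the streamed token tensors. `model` and `inputs` are assumed to be a
# HF model and its tokenized inputs.
def _example_iterate_tokens(model, inputs):
    from threading import Thread
    streamer = TokensIteratorStreamer()
    thread = Thread(target=model.generate, kwargs={**inputs, 'streamer': streamer})
    thread.start()
    token_chunks = list(streamer)  # iteration stops once end() enqueues the None sentinel
    thread.join()
    return token_chunks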


class LogitsStreamer(LogitsProcessor):

    def __init__(self):
        self.queue = Queue()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        self.queue.put(scores)
        return scores
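

# Illustrative sketch (not part of the original module): because LogitsStreamer is a regular
# LogitsProcessor that returns the scores unchanged, it can be appended to the processor list
# passed to `generate()` in order to capture per-step logits. `model` and `inputs` are assumed
# to be a HF model and its tokenized inputs.
def _example_capture_logits(model, inputs):
    from transformers import LogitsProcessorList
    logits_streamer = LogitsStreamer()
    model.generate(**inputs, logits_processor=LogitsProcessorList([logits_streamer]), max_new_tokens=8)
    step_scores = []
    while not logits_streamer.queue.empty():
        step_scores.append(logits_streamer.queue.get())
    return step_scores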


def _set_generation_config_default_value(model_generation_config: GenerationConfig,
                                         generation_config: GenerationConfig) -> GenerationConfig:
    for k, v in model_generation_config.to_dict().items():
        new_v = getattr(generation_config, k, None)
        if k in ['max_length']:
            continue
        if k in ['no_repeat_ngram_size'] or v is not None and new_v is None:
            setattr(generation_config, k, v)
    return generation_config


def prepare_generation_config(model_generation_config: Optional[GenerationConfig], request_config: RequestConfig,
                              tokenizer) -> Optional[GenerationConfig]:
    if model_generation_config is None or request_config is None:
        return model_generation_config
    kwargs = {'max_new_tokens': request_config.max_tokens}

    for key in ['length_penalty']:
        kwargs[key] = getattr(request_config, key)
    for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty', 'num_beams']:
        new_value = getattr(request_config, key)
        if new_value is None:
            kwargs[key] = getattr(model_generation_config, key)
        else:
            kwargs[key] = new_value

    if not model_generation_config.do_sample and request_config.temperature in {0, None}:
        kwargs['temperature'] = 0
    if kwargs['temperature'] == 0:
        kwargs['do_sample'] = False
        kwargs['temperature'] = 1
        kwargs['top_p'] = 1
        kwargs['top_k'] = 50
    else:
        kwargs['do_sample'] = True
    generation_config = GenerationConfig(**kwargs)
    generation_config = _set_generation_config_default_value(model_generation_config, generation_config)
    fix_do_sample_warning(generation_config)

    if generation_config.eos_token_id is None:
        generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.pad_token_id = tokenizer.pad_token_id
    return generation_config
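

# Illustrative sketch (not part of the original module): request-level sampling parameters
# override the model defaults, and a temperature of 0 is mapped to greedy decoding
# (do_sample=False with neutral top_p/top_k). The RequestConfig keyword arguments used here
# are assumptions based on the attributes read above.
def _example_prepare_generation_config(model, tokenizer) -> GenerationConfig:
    request_config = RequestConfig(max_tokens=64, temperature=0.)
    generation_config = prepare_generation_config(model.generation_config, request_config, tokenizer)
    # generation_config.do_sample is now False because temperature == 0.
    return generation_config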


def patch_lmdeploy(load_weights=False):
    """This patch allows lmdeploy to select devices and to reload a state_dict."""
    import lmdeploy
    assert version.parse(lmdeploy.__version__) >= version.parse('0.7.0')
    from lmdeploy.messages import TurbomindEngineConfig
    from lmdeploy.turbomind.deploy import loader
    from lmdeploy.turbomind.deploy.loader import create_loader
    from lmdeploy.turbomind.deploy.source_model import llama

    def _create_loader(model_path: str, pattern: str):
        # When model_path is not a path but an in-memory state_dict (or an iterator of
        # (key, value) pairs), group the weights by layer index instead of reading from disk.
        if not isinstance(model_path, (str, os.PathLike)):

            def generate():
                generator = OrderedDict()
                model_dict = {}
                if not isinstance(model_path, dict):
                    for key, value in list(model_path):
                        model_dict[key] = value
                else:
                    model_dict = model_path
                for key, value in model_dict.items():
                    match = re.findall(pattern, key)
                    if not match:
                        # Weights that do not belong to a specific layer go into bucket -1.
                        if -1 not in generator:
                            generator[-1] = {}
                        generator[-1][key] = value
                    else:
                        layer = int(match[0])
                        if layer not in generator:
                            generator[layer] = {}
                        generator[layer][key] = value
                return generator

            return generate()
        else:
            return create_loader(model_path, pattern)

    loader.create_loader = _create_loader
    llama.create_loader = _create_loader

    # Default device list; can be overridden on each engine config instance.
    TurbomindEngineConfig.devices = [0]

    from lmdeploy.turbomind.turbomind import TurboMind
    from lmdeploy.turbomind.utils import ModelSource

    @contextmanager
    def patch_threadpool_map():
        # Temporarily make ThreadPoolExecutor.map a no-op while the original __init__ runs;
        # the per-device work is then redone explicitly below on the selected gpu_list.
        ThreadPoolExecutor.map_origin = ThreadPoolExecutor.map
        ThreadPoolExecutor.map = lambda *args, **kwargs: []
        yield
        ThreadPoolExecutor.map = ThreadPoolExecutor.map_origin
        del ThreadPoolExecutor.map_origin

    @contextmanager
    def tm_model_context(self):

        def _get_tm_model(model_path,
                          model_name,
                          chat_template_name,
                          engine_config: TurbomindEngineConfig,
                          group_size: int = None,
                          out_dir: str = None):
            from lmdeploy.turbomind.deploy.converter import get_tm_model_origin
            tm_model = get_tm_model_origin(model_path, model_name, chat_template_name, engine_config, group_size,
                                           out_dir)
            # Keep a handle on the converted model so that load_weights can re-export it later.
            self.tm_model = tm_model
            return tm_model

        from lmdeploy.turbomind.deploy import converter
        converter.get_tm_model_origin = converter.get_tm_model
        converter.get_tm_model = _get_tm_model
        yield
        converter.get_tm_model = converter.get_tm_model_origin
        del converter.get_tm_model_origin

    def __init__(self,
                 model_path: str,
                 tokenizer: object,
                 model_name: str = None,
                 chat_template_name: str = None,
                 engine_config: TurbomindEngineConfig = None,
                 model_source: ModelSource = ModelSource.WORKSPACE,
                 **kwargs):
        self.gpu_list = engine_config.devices
        with patch_threadpool_map(), tm_model_context(self):
            self.__origin_init__(model_path, tokenizer, model_name, chat_template_name, engine_config, model_source,
                                 **kwargs)

        with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
            ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
            if not load_weights:
                for _ in e.map(self.model_comm.process_weight, self.gpu_list, ranks):
                    pass
            if version.parse(lmdeploy.__version__) < version.parse('0.7.2'):
                for _ in e.map(self.model_comm.create_engine, self.gpu_list, ranks, repeat(self.nccl_params)):
                    pass
            else:
                for _ in e.map(self.model_comm.create_engine, self.gpu_list, ranks):
                    pass

    def _create_weight(self, model_comm):
        """Allocate weight buffer, load params if from_workspace."""
        self.node_id = 0
        self.node_num = 1
        if version.parse(lmdeploy.__version__) < version.parse('0.7.2'):
            self.nccl_params = model_comm.create_nccl_params(self.node_id)
        torch.cuda.synchronize()

        def _create_weight_func(index, device_id):
            rank = self.node_id * self.gpu_count + index
            model_comm.create_shared_weights(device_id, rank)

        with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
            futures = []
            for idx, device_id in enumerate(self.gpu_list):
                futures.append(executor.submit(_create_weight_func, idx, device_id))
            for future in futures:
                future.result()

    def _get_model_params(self, model_comm, tm_params):
        """Get turbomind model params when loading from hf."""

        def _get_params(idx, device_id, que):
            rank = self.node_id * self.gpu_count + idx
            out = model_comm.get_params(device_id, rank)
            que.put(out)

        que = Queue()
        with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
            futures = []
            for idx, device_id in enumerate(self.gpu_list):
                futures.append(executor.submit(_get_params, idx, device_id, que))
            for future in futures:
                future.result()

        for _ in range(self.gpu_count):
            tensor_map = que.get()
            for k, v in tensor_map.items():
                if k not in tm_params:
                    tm_params[k] = []
                tm_params[k].append(v)

    def _load_weights(self, state_dict):
        # Re-export the turbomind model directly from an in-memory state_dict: model_path is
        # temporarily replaced so the patched create_loader reads the tensors from memory.
        tm_params = self.tm_model.tm_params
        self._get_model_params(self.model_comm, tm_params)
        input_model = self.tm_model.input_model
        model_path = input_model.model_path
        input_model.model_path = state_dict
        self.tm_model.export()
        input_model.model_path = model_path

    from lmdeploy.turbomind.turbomind import TurboMindInstance

    def create_instance(self, cuda_stream_id=0):
        return TurboMindInstance(self, self.config, cuda_stream_id, self.gpu_list)

    TurboMind.__origin_init__ = TurboMind.__init__
    TurboMind.__init__ = __init__
    TurboMind._create_weight = _create_weight
    TurboMind._get_model_params = _get_model_params
    TurboMind.create_instance = create_instance
    if load_weights:
        TurboMind.load_weights = _load_weights

    def __init_ins__(self, tm_model, config, cuda_stream_id=0, gpu_list=None):
        if gpu_list is None:
            gpu_list = [0]
        self.gpu_list = gpu_list
        self.__origin_init__(tm_model, config, cuda_stream_id)

    def _create_model_instance(self, device_id):
        # Always create the instance on the first device of the selected gpu_list.
        model_inst = self.tm_model.model_comm.create_model_instance(self.gpu_list[0])
        return model_inst

    TurboMindInstance.__origin_init__ = TurboMindInstance.__init__
    TurboMindInstance.__init__ = __init_ins__
    TurboMindInstance._create_model_instance = _create_model_instance
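

# Illustrative sketch (not part of the original module): the patch is applied before a TurboMind
# engine is built so that the `devices` attribute it adds to TurbomindEngineConfig is honored
# and, with load_weights=True, `load_weights(state_dict)` becomes available on the engine.
# `model_dir` is a hypothetical local model path.
def _example_patch_lmdeploy(model_dir: str):
    patch_lmdeploy(load_weights=True)
    from lmdeploy import TurbomindEngineConfig, pipeline
    engine_config = TurbomindEngineConfig()
    engine_config.devices = [1]  # attribute introduced by the patch above
    return pipeline(model_dir, backend_config=engine_config)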


def patch_vllm(world_size=1):

    @contextmanager
    def _get_context():
        from vllm.distributed.parallel_state import GroupCoordinator
        from unittest.mock import patch
        try:
            # Skip vLLM's profiling-time memory check when the patch target exists
            # (the bare getattr raises AttributeError otherwise).
            from vllm.worker.worker import Worker
            getattr(Worker, '_assert_memory_footprint_increased_during_profiling')
            profiling_patch = patch(
                'vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling', return_value=None)
        except (ImportError, AttributeError):
            profiling_patch = nullcontext()

        __origin_init__ = GroupCoordinator.__init__

        def get_world_size(group=None) -> int:
            # Report the patched world size when no explicit process group is given.
            if not group:
                return world_size
            else:
                return torch.distributed.get_world_size_origin(group)

        def __init__(self, group_ranks, local_rank, *args, **kwargs):
            node_rank, nnodes = get_node_setting()
            device_count = get_device_count()
            num_infer_workers = world_size // nnodes

            def map_rank_to_real_device(obj):
                # Inference workers occupy the last `num_infer_workers` devices of each node,
                # so local ranks are shifted by the number of training devices.
                diff = device_count - num_infer_workers
                if diff < 0:
                    diff = 0
                if isinstance(obj, list):
                    return [map_rank_to_real_device(o) for o in obj]
                elif isinstance(obj, int):
                    return obj + diff
                else:
                    raise ValueError(f'Unsupported type: {obj}')

            if kwargs.get('group_name') == 'world':
                local_rank = local_rank + node_rank * num_infer_workers
            else:
                local_rank = map_rank_to_real_device(local_rank - node_rank * num_infer_workers)
            rank = dist.get_rank()
            if world_size == 1 and [rank] not in group_ranks:
                group_ranks = [[rank]]
            if nnodes > 1 and num_infer_workers < device_count:
                """
                Map group_ranks to global ranks

                Example:
                - Number of nodes (nnodes): 2
                - Devices per node (device_count): 4
                - Inference workers per node (num_infer_workers): 1

                Initial group_ranks:
                [[0, 1]]

                After mapping to global ranks:
                [[0, 3]]  # Global ranks corresponding to the local ranks
                """
                train_device_count = device_count - num_infer_workers

                if len(group_ranks) == 1:
                    group_ranks = group_ranks[0]
                    for i in range(nnodes):
                        group_ranks[i * num_infer_workers:(i + 1) * num_infer_workers] = [
                            train_device_count * i + j for j in range(num_infer_workers)
                        ]
                    group_ranks = [group_ranks]
                else:
                    for i in range(nnodes):
                        for j in range(num_infer_workers):
                            group_ranks[i * num_infer_workers + j] = [train_device_count * i + j]

            return __origin_init__(self, group_ranks, local_rank, *args, **kwargs)

        GroupCoordinator.__init__ = __init__

        try:
            with profiling_patch, restore_torch_device_after_vllm_init():
                torch.distributed.get_world_size_origin = torch.distributed.get_world_size
                torch.distributed.get_world_size = get_world_size
                yield
                torch.distributed.get_world_size = torch.distributed.get_world_size_origin
                del torch.distributed.get_world_size_origin
        finally:
            GroupCoordinator.__init__ = __origin_init__

    return _get_context() if dist.is_initialized() else nullcontext()
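

# Illustrative sketch (not part of the original module): patch_vllm is used as a context manager
# around vLLM engine construction, so that GroupCoordinator and torch.distributed report the
# intended inference world size even when the process group is shared with training workers.
# `model_dir` is a hypothetical local model path.
def _example_patch_vllm(model_dir: str):
    from vllm import LLM
    with patch_vllm(world_size=1):
        return LLM(model=model_dir, enforce_eager=True)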


def patch_npu_vllm(vllm_device: str):
    if isinstance(vllm_device, int):
        vllm_device = get_device(vllm_device)
    device_type = vllm_device.split(':')[0]

    @contextmanager
    def new_group_context():
        original_new_group = torch.distributed.new_group
        try:
            torch.distributed.new_group = partial(original_new_group, use_local_synchronization=True)
            torch.npu.mem_get_info = partial(torch.npu.mem_get_info, device=vllm_device)
            yield
        finally:
            torch.distributed.new_group = original_new_group

    return new_group_context() if device_type == 'npu' else nullcontext()


@contextmanager
def set_device_context(device: Union[str, int]):
    origin_device = get_current_device()
    set_device(device)
    try:
        yield
    finally:
        set_device(origin_device)


@contextmanager
def restore_torch_device_after_vllm_init():
    """
    A context manager to restore the original CUDA device after potential modifications.

    This is specifically designed to address an issue in Distributed Data Parallel (DDP)
    scenarios where the initialization of the vLLM engine may inadvertently modify the
    default CUDA device. The context manager saves the current device at the start and
    ensures it is restored upon exit, even if the device is modified within the context.
    """
    origin_device = get_current_device()
    try:
        yield
    finally:
        current_device = get_current_device()
        if origin_device != current_device:
            set_device(origin_device)


def patch_vllm_memory_leak():
    import vllm
    if version.parse(vllm.__version__) != version.parse('0.7.3'):
        return

    def patch_vllm_abort_seq_group():
        from vllm.core.scheduler import Scheduler
        from typing import Iterable, Dict
        from vllm.sequence import SequenceGroupBase, SequenceGroup, SequenceStatus

        def new_abort_seq_group(
            self,
            request_id: Union[str, Iterable[str]],
            seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
        ) -> None:
            if isinstance(request_id, str):
                request_id = (request_id, )
            request_ids = set(request_id)
            seq_id_to_seq_group = seq_id_to_seq_group or {}
            for state_queue in [self.waiting, self.running, self.swapped]:
                aborted_groups: List[SequenceGroup] = []
                for seq_group in state_queue:
                    # When n > 1, seq_group.request_id is the per-sample id; resolve it to the
                    # real request id through seq_id_to_seq_group before matching.
                    if seq_group.request_id in seq_id_to_seq_group:
                        real_request_id = seq_id_to_seq_group[seq_group.request_id].group_id
                    else:
                        real_request_id = seq_group.request_id
                    if real_request_id in request_ids:
                        # Do not remove real_request_id from request_ids here: other sequence
                        # groups may share the same real request id.
                        aborted_groups.append(seq_group)
                for aborted_group in aborted_groups:
                    # Remove the sequence group from the state queue and free its sequences.
                    state_queue.remove(aborted_group)
                    self._finished_requests_ids.append(aborted_group.request_id)
                    for seq in aborted_group.get_seqs():
                        if seq.is_finished():
                            continue
                        seq.status = SequenceStatus.FINISHED_ABORTED
                        self.free_seq(seq)
                    # Memory-leak fix: drop the aborted entry from seq_id_to_seq_group so it
                    # cannot accumulate across requests.
                    if aborted_group.request_id in seq_id_to_seq_group:
                        del seq_id_to_seq_group[aborted_group.request_id]

                    self._free_seq_group_cross_attn_blocks(aborted_group)

        origin_method = Scheduler.abort_seq_group
        Scheduler._old_abort_seq_group = origin_method
        Scheduler.abort_seq_group = new_abort_seq_group

    def patch_vllm_engine():
        from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
        from vllm.outputs import PoolingRequestOutput, RequestOutput
        from vllm.sequence import ExecuteModelRequest

        def new_abort_request(self, request_id) -> None:
            # Forward seq_id_to_seq_group so the patched scheduler can clean it up (see above).
            for scheduler in self.scheduler:
                scheduler.abort_seq_group(request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)

        origin_method = LLMEngine.abort_request
        LLMEngine._old_abort_request = origin_method
        LLMEngine.abort_request = new_abort_request

        def new_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
            if self.parallel_config.pipeline_parallel_size > 1:
                raise NotImplementedError('Pipeline parallelism is only supported through AsyncLLMEngine '
                                          'as performance will be severely degraded otherwise.')

            # LLMEngine has no pipeline-parallel support, so virtual engine 0 is always used.
            virtual_engine = 0

            # Cached outputs from previous iterations (None on the first iteration).
            cached_outputs = self.cached_scheduler_outputs[virtual_engine]
            seq_group_metadata_list = cached_outputs.seq_group_metadata_list
            scheduler_outputs = cached_outputs.scheduler_outputs
            allow_async_output_proc = cached_outputs.allow_async_output_proc

            ctx = self.scheduler_contexts[virtual_engine]

            # Clear outputs for each new scheduler iteration.
            ctx.request_outputs.clear()

            # Skip the scheduler if there are any remaining steps in the seq groups.
            if not self._has_remaining_steps(seq_group_metadata_list):
                # Schedule iteration.
                (seq_group_metadata_list, scheduler_outputs,
                 allow_async_output_proc) = self.scheduler[virtual_engine].schedule()

                ctx.seq_group_metadata_list = seq_group_metadata_list
                ctx.scheduler_outputs = scheduler_outputs

                finished_requests_ids = self.scheduler[virtual_engine].get_and_reset_finished_requests_ids()

                # Memory-leak fix: entries in self.seq_id_to_seq_group must be deleted here,
                # otherwise they accumulate across finished requests.
                for finished_request_id in finished_requests_ids:
                    if finished_request_id in self.seq_id_to_seq_group:
                        del self.seq_id_to_seq_group[finished_request_id]

                # Maybe switch from async mode to sync mode.
                if not allow_async_output_proc and len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)

                if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0):
                    # Cache the scheduler outputs for the next iteration if there are lookahead slots.
                    self._cache_scheduler_outputs_for_multi_step(virtual_engine, seq_group_metadata_list,
                                                                 scheduler_outputs, allow_async_output_proc)
            else:
                finished_requests_ids = list()

            assert seq_group_metadata_list is not None
            assert scheduler_outputs is not None

            if not scheduler_outputs.is_empty():
                # Reuse the last sampled token ids cached from the previous iteration.
                last_sampled_token_ids = \
                    self._get_last_sampled_token_ids(virtual_engine)

                execute_model_req = ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list,
                    blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                    blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                    blocks_to_copy=scheduler_outputs.blocks_to_copy,
                    num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                    running_queue_size=scheduler_outputs.running_queue_size,
                    finished_requests_ids=finished_requests_ids,
                    last_sampled_token_ids=last_sampled_token_ids)

                if allow_async_output_proc:
                    execute_model_req.async_callback = self.async_callbacks[virtual_engine]

                outputs = self.model_executor.execute_model(execute_model_req=execute_model_req)

                # Needed so that the last step's sampled_token_ids can be passed on.
                if self.scheduler_config.is_multi_step:
                    self._update_cached_scheduler_output(virtual_engine, outputs)
            else:
                # Nothing scheduled: if there is a pending async postprocessor, finish it here.
                if len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)
                # No outputs in this case.
                outputs = []

            # Finish the current step for all the sequence groups.
            if self.scheduler_config.is_multi_step:
                for seq_group in seq_group_metadata_list:
                    seq_group.finish_step()

            if not self._has_remaining_steps(seq_group_metadata_list):
                # Clear the cache once all steps are finished.
                if self.scheduler_config.is_multi_step:
                    self.cached_scheduler_outputs[0] = SchedulerOutputState()

                # is_first_step_output is True only when num_steps of all sequences is 1.
                is_first_step_output: bool = False if not seq_group_metadata_list \
                    else seq_group_metadata_list[0].state.num_steps == 1

                # Add results to the output queue.
                ctx.append_output(
                    outputs=outputs,
                    seq_group_metadata_list=seq_group_metadata_list,
                    scheduler_outputs=scheduler_outputs,
                    is_async=allow_async_output_proc,
                    is_last_step=True,
                    is_first_step_output=is_first_step_output)

                if outputs and allow_async_output_proc:
                    assert len(outputs) == 1, ('Async postprocessor expects only a single output set')

                    self._advance_to_next_step(outputs[0], seq_group_metadata_list,
                                               scheduler_outputs.scheduled_seq_groups)

                # Run the usual non-async path if needed.
                if not allow_async_output_proc:
                    self._process_model_outputs(ctx=ctx)

                    # Log stats.
                    self.do_log_stats(scheduler_outputs, outputs)

                    # Tracing.
                    self.do_tracing(scheduler_outputs)
            else:
                # Multi-step case.
                return ctx.request_outputs

            if not self.has_unfinished_requests():
                # Drain the async postprocessor (if it exists).
                if len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)
                assert len(ctx.output_queue) == 0

                # Stop the execute-model loop in parallel workers until there are more requests
                # to process, so distributed ops do not wait indefinitely.
                self.model_executor.stop_remote_worker_execution_loop()

            return ctx.request_outputs

        origin_method = LLMEngine.step
        LLMEngine._old_step = origin_method
        LLMEngine.step = new_step

    patch_vllm_abort_seq_group()
    patch_vllm_engine()
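

# Illustrative note (not part of the original module): patch_vllm_memory_leak only targets
# vLLM 0.7.3 (it returns early for any other version). It monkey-patches
# Scheduler.abort_seq_group and LLMEngine.abort_request/step at the class level, keeping the
# originals as _old_abort_seq_group, _old_abort_request and _old_step.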