# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import re
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from functools import partial
from itertools import repeat
from queue import Queue
from typing import List, Optional, Union
import torch
import torch.distributed as dist
from packaging import version
from transformers import GenerationConfig, LogitsProcessor
from transformers.generation.streamers import BaseStreamer
from swift.llm.model.register import fix_do_sample_warning
from swift.utils import get_current_device, get_device, get_device_count, get_node_setting, set_device
from ..protocol import RequestConfig
@dataclass
class AdapterRequest:
name: str
path: str
class InferTools:
@staticmethod
def _is_chinese_char(cp: int) -> bool:
"""Checks whether CP is the codepoint of a CJK character."""
# copy from transformers.generation.streamers.TextStreamer
if ((0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF) or (0x20000 <= cp <= 0x2A6DF)
or (0x2A700 <= cp <= 0x2B73F) or (0x2B740 <= cp <= 0x2B81F) or (0x2B820 <= cp <= 0x2CEAF)
or (0xF900 <= cp <= 0xFAFF) or (0x2F800 <= cp <= 0x2FA1F)):
return True
return False
class InferStreamer(InferTools):
def __init__(self, template, **decode_kwargs):
self.template = template
self.tokenizer = template.tokenizer
self.cache_idx = 0 # token idx
self.print_idx = 0
self.decode_kwargs = decode_kwargs
self.first_num_space = -1 # The number of whitespace characters before the first token.
self.first_token = True
def _align_blank_suffix(self, response: str) -> str:
# Keep the number of leading spaces consistent with the first chunk to avoid repeated or missing words in the streamed sentence.
cur_num_space = len(response) - len(response.lstrip(' '))
if self.first_num_space == -1:
self.first_num_space = cur_num_space
elif cur_num_space < self.first_num_space:
response = ' ' * (self.first_num_space - cur_num_space) + response
elif cur_num_space > self.first_num_space:
response = response[cur_num_space - self.first_num_space:]
return response
def _get_response(self, response: str, is_finished: bool, token_len: int) -> str:
# After the symbol for a new line, we flush the cache.
if self.first_token:
printable_text = response
self.first_token = False
elif response.endswith('\n') or is_finished:
printable_text = response[self.print_idx:]
self.cache_idx += token_len
self.first_num_space = -1
self.print_idx = 0
# If the last token is a CJK character, we print the characters.
elif len(response) > 0 and self._is_chinese_char(ord(response[-1])):
printable_text = response[self.print_idx:]
self.print_idx += len(printable_text)
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
# which may change with the subsequent token -- there are probably smarter ways to do this!)
else:
printable_text = response[self.print_idx:response.rfind(' ') + 1]
self.print_idx += len(printable_text)
return printable_text
def get_printable_text(self, raw_tokens: List[int], is_finished: bool) -> str:
raw_tokens = raw_tokens[self.cache_idx:]
if self.first_token:
raw_tokens = []
response = self.template.decode(
raw_tokens, is_finished=is_finished, tokenizer_kwargs=self.decode_kwargs, first_token=self.first_token)
response = self._align_blank_suffix(response)
return self._get_response(response, is_finished, len(raw_tokens))
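# Illustrative usage sketch (not part of the original file): driving InferStreamer from a
# token-by-token generation loop. `template` is assumed to be a swift Template whose
# `decode` accepts the keyword arguments used above; `generated_token_ids` is a plain list.
def _example_infer_streamer_usage(template, generated_token_ids):
    streamer = InferStreamer(template)
    tokens = []
    for i, token_id in enumerate(generated_token_ids):
        tokens.append(token_id)
        # Pass is_finished=True on the last step so the remaining cache is flushed.
        delta = streamer.get_printable_text(tokens, is_finished=i == len(generated_token_ids) - 1)
        if delta:
            print(delta, end='', flush=True)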
class StreamerMixin:
def __init__(self):
self.queue = Queue()
def __iter__(self):
return self
def __next__(self) -> torch.Tensor:
value = self.queue.get()
if value is None:
raise StopIteration()
else:
return value
class TokensIteratorStreamer(StreamerMixin, BaseStreamer):
def put(self, value: torch.Tensor) -> None:
self.queue.put(value)
def end(self) -> None:
self.queue.put(None)
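# Illustrative usage sketch (assumption, not part of the original file): because
# TokensIteratorStreamer follows the transformers BaseStreamer protocol, it can be handed
# to `model.generate(streamer=...)` running in a background thread while the caller
# iterates over the queued token tensors; iteration stops at the None sentinel put by end().
def _example_tokens_iterator_streamer(model, generate_kwargs):
    from threading import Thread
    streamer = TokensIteratorStreamer()
    thread = Thread(target=model.generate, kwargs={**generate_kwargs, 'streamer': streamer})
    thread.start()
    for token_tensor in streamer:
        yield token_tensor
    thread.join()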
class LogitsStreamer(LogitsProcessor):
def __init__(self):
self.queue = Queue()
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
self.queue.put(scores)
return scores
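# Illustrative usage sketch (assumption): LogitsStreamer is a LogitsProcessor that copies
# each step's scores into a queue, so it can be appended to the logits_processor list of
# `model.generate` to observe logits while generation runs (typically in another thread).
def _example_logits_streamer(model, generate_kwargs):
    from transformers import LogitsProcessorList
    logits_streamer = LogitsStreamer()
    output = model.generate(**generate_kwargs, logits_processor=LogitsProcessorList([logits_streamer]))
    # Per-step score tensors are now available from logits_streamer.queue.
    return output, logits_streamer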
def _set_generation_config_default_value(model_generation_config: GenerationConfig,
generation_config: GenerationConfig) -> GenerationConfig:
for k, v in model_generation_config.to_dict().items():
new_v = getattr(generation_config, k, None)
if k in ['max_length']:
continue
if k in ['no_repeat_ngram_size'] or (v is not None and new_v is None):
setattr(generation_config, k, v)
return generation_config
def prepare_generation_config(model_generation_config: Optional[GenerationConfig], request_config: RequestConfig,
tokenizer) -> Optional[GenerationConfig]:
if model_generation_config is None or request_config is None:
return model_generation_config
kwargs = {'max_new_tokens': request_config.max_tokens}
# Not used: 'n', 'best_of', 'frequency_penalty', 'presence_penalty'.
for key in ['length_penalty']:
kwargs[key] = getattr(request_config, key)
for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty', 'num_beams']:
new_value = getattr(request_config, key)
if new_value is None:
kwargs[key] = getattr(model_generation_config, key)
else:
kwargs[key] = new_value
if not model_generation_config.do_sample and request_config.temperature in {0, None}:
kwargs['temperature'] = 0
if kwargs['temperature'] == 0:
kwargs['do_sample'] = False
kwargs['temperature'] = 1
kwargs['top_p'] = 1
kwargs['top_k'] = 50
else:
kwargs['do_sample'] = True
generation_config = GenerationConfig(**kwargs)
generation_config = _set_generation_config_default_value(model_generation_config, generation_config)
fix_do_sample_warning(generation_config)
if generation_config.eos_token_id is None:
generation_config.eos_token_id = tokenizer.eos_token_id
generation_config.pad_token_id = tokenizer.pad_token_id
return generation_config
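# Illustrative usage sketch (assumption): merging a request's sampling options into the
# model's GenerationConfig. The RequestConfig field names follow the attributes read
# above; constructing it with keyword arguments is assumed to be supported.
def _example_prepare_generation_config(model, tokenizer):
    request_config = RequestConfig(max_tokens=512, temperature=0.7, top_p=0.9)
    generation_config = prepare_generation_config(model.generation_config, request_config, tokenizer)
    # temperature == 0 (or do_sample=False on the model config with no temperature set)
    # yields greedy decoding; unset fields fall back to the model's defaults.
    return generation_config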
def patch_lmdeploy(load_weights=False):
"""This patch allows lmdeploy selects device and reload state_dict"""
import lmdeploy
assert version.parse(lmdeploy.__version__) >= version.parse('0.7.0')
from lmdeploy.messages import TurbomindEngineConfig
from lmdeploy.turbomind.deploy import loader
from lmdeploy.turbomind.deploy.loader import create_loader
from lmdeploy.turbomind.deploy.source_model import llama
def _create_loader(model_path: str, pattern: str):
if not isinstance(model_path, (str, os.PathLike)):
def generate():
generator = OrderedDict()
model_dict = {}
if not isinstance(model_path, dict):
for key, value in list(model_path):
model_dict[key] = value
else:
model_dict = model_path
for key, value in model_dict.items():
match = re.findall(pattern, key)
if not match:
if -1 not in generator:
generator[-1] = {}
generator[-1][key] = value
else:
layer = int(match[0])
if layer not in generator:
generator[layer] = {}
generator[layer][key] = value
return generator
return generate()
else:
return create_loader(model_path, pattern)
loader.create_loader = _create_loader
llama.create_loader = _create_loader
TurbomindEngineConfig.devices = [0]
from lmdeploy.turbomind.turbomind import TurboMind
from lmdeploy.turbomind.utils import ModelSource
@contextmanager
def patch_threadpool_map():
ThreadPoolExecutor.map_origin = ThreadPoolExecutor.map
ThreadPoolExecutor.map = lambda *args, **kwargs: []
yield
ThreadPoolExecutor.map = ThreadPoolExecutor.map_origin
del ThreadPoolExecutor.map_origin
@contextmanager
def tm_model_context(self):
def _get_tm_model(model_path,
model_name,
chat_template_name,
engine_config: TurbomindEngineConfig,
group_size: int = None,
out_dir: str = None):
from lmdeploy.turbomind.deploy.converter import get_tm_model_origin
tm_model = get_tm_model_origin(model_path, model_name, chat_template_name, engine_config, group_size,
out_dir)
self.tm_model = tm_model
return tm_model
from lmdeploy.turbomind.deploy import converter
converter.get_tm_model_origin = converter.get_tm_model
converter.get_tm_model = _get_tm_model
yield
converter.get_tm_model = converter.get_tm_model_origin
del converter.get_tm_model_origin
def __init__(self,
model_path: str,
tokenizer: object,
model_name: str = None,
chat_template_name: str = None,
engine_config: TurbomindEngineConfig = None,
model_source: ModelSource = ModelSource.WORKSPACE,
**kwargs):
self.gpu_list = engine_config.devices
with patch_threadpool_map(), tm_model_context(self):
self.__origin_init__(model_path, tokenizer, model_name, chat_template_name, engine_config, model_source,
**kwargs)
with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
if not load_weights:
for _ in e.map(self.model_comm.process_weight, self.gpu_list, ranks):
pass
if version.parse(lmdeploy.__version__) < version.parse('0.7.2'):
for _ in e.map(self.model_comm.create_engine, self.gpu_list, ranks, repeat(self.nccl_params)):
pass
else:
for _ in e.map(self.model_comm.create_engine, self.gpu_list, ranks):
pass
def _create_weight(self, model_comm):
"""Allocate weight buffer, load params if from_workspace."""
# TODO: support mpi
self.node_id = 0
self.node_num = 1
if version.parse(lmdeploy.__version__) < version.parse('0.7.2'):
self.nccl_params = model_comm.create_nccl_params(self.node_id)
torch.cuda.synchronize()
# create weight
def _create_weight_func(index, device_id):
rank = self.node_id * self.gpu_count + index
model_comm.create_shared_weights(device_id, rank)
with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
futures = []
for idx, device_id in enumerate(self.gpu_list):
futures.append(executor.submit(_create_weight_func, idx, device_id))
for future in futures:
future.result()
def _get_model_params(self, model_comm, tm_params):
"""Get turbomind model params when loading from hf."""
def _get_params(idx, device_id, que):
rank = self.node_id * self.gpu_count + idx
out = model_comm.get_params(device_id, rank)
que.put(out)
que = Queue()
with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
futures = []
for idx, device_id in enumerate(self.gpu_list):
futures.append(executor.submit(_get_params, idx, device_id, que))
for future in futures:
future.result()
for _ in range(self.gpu_count):
tensor_map = que.get()
for k, v in tensor_map.items():
if k not in tm_params:
tm_params[k] = []
tm_params[k].append(v)
def _load_weights(self, state_dict):
tm_params = self.tm_model.tm_params
self._get_model_params(self.model_comm, tm_params)
input_model = self.tm_model.input_model
model_path = input_model.model_path
input_model.model_path = state_dict
self.tm_model.export()
input_model.model_path = model_path
from lmdeploy.turbomind.turbomind import TurboMindInstance
def create_instance(self, cuda_stream_id=0):
return TurboMindInstance(self, self.config, cuda_stream_id, self.gpu_list)
TurboMind.__origin_init__ = TurboMind.__init__
TurboMind.__init__ = __init__
TurboMind._create_weight = _create_weight
TurboMind._get_model_params = _get_model_params
TurboMind.create_instance = create_instance
if load_weights:
TurboMind.load_weights = _load_weights
def __init_ins__(self, tm_model, config, cuda_stream_id=0, gpu_list=None):
if gpu_list is None:
gpu_list = [0]
self.gpu_list = gpu_list
self.__origin_init__(tm_model, config, cuda_stream_id)
def _create_model_instance(self, device_id):
model_inst = self.tm_model.model_comm.create_model_instance(self.gpu_list[0])
return model_inst
TurboMindInstance.__origin_init__ = TurboMindInstance.__init__
TurboMindInstance.__init__ = __init_ins__
TurboMindInstance._create_model_instance = _create_model_instance
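# Illustrative usage sketch (assumption): patch_lmdeploy must run before the TurboMind
# engine is constructed so the patched __init__/create_instance honor
# TurbomindEngineConfig.devices; with load_weights=True the engine gains `load_weights`
# for reloading a state_dict in place. `build_engine_fn` is a hypothetical callable.
def _example_patch_lmdeploy(build_engine_fn, state_dict=None):
    patch_lmdeploy(load_weights=state_dict is not None)
    engine = build_engine_fn()
    if state_dict is not None:
        engine.load_weights(state_dict)
    return engine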
def patch_vllm(world_size=1):
@contextmanager
def _get_context():
from vllm.distributed.parallel_state import GroupCoordinator
from unittest.mock import patch
try:
from vllm.worker.worker import Worker
getattr(Worker, '_assert_memory_footprint_increased_during_profiling')
profiling_patch = patch(
'vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling', return_value=None)
except (ImportError, AttributeError):
profiling_patch = nullcontext()
__origin_init__ = GroupCoordinator.__init__
def get_world_size(group=None) -> int:
if not group:
# Return the world_size passed to patch_vllm.
return world_size
else:
return torch.distributed.get_world_size_origin(group)
def __init__(self, group_ranks, local_rank, *args, **kwargs):
node_rank, nnodes = get_node_setting()
device_count = get_device_count()
num_infer_workers = world_size // nnodes
def map_rank_to_real_device(obj):
# Use the last devices:
# with world_size=4 and 8 GPUs, ranks [0, 1, 2, 3] map to devices [4, 5, 6, 7].
diff = device_count - num_infer_workers
if diff < 0:
diff = 0
if isinstance(obj, list):
return [map_rank_to_real_device(o) for o in obj]
elif isinstance(obj, int):
return obj + diff
else:
raise ValueError(f'Unsupported type: {obj}')
if kwargs.get('group_name') == 'world':
local_rank = local_rank + node_rank * num_infer_workers
else:
local_rank = map_rank_to_real_device(local_rank - node_rank * num_infer_workers)
rank = dist.get_rank()
if world_size == 1 and [rank] not in group_ranks:
# for ddp inference
group_ranks = [[rank]]
if nnodes > 1 and num_infer_workers < device_count:
"""
Map group_ranks to global ranks
Example:
- Number of nodes (nnodes): 2
- Devices per node (device_count): 4
- Inference workers per node (num_infer_workers): 1
Initial group_ranks:
[[0, 1]]
After mapping to global ranks:
[[0, 3]] # Global ranks corresponding to the local ranks
"""
train_device_count = device_count - num_infer_workers
# vllm.worker.init_distributed_environment
if len(group_ranks) == 1:
group_ranks = group_ranks[0]
for i in range(nnodes):
group_ranks[i * num_infer_workers:(i + 1) * num_infer_workers] = [
train_device_count * i + j for j in range(num_infer_workers)
]
group_ranks = [group_ranks]
# vllm.worker.ensure_model_parallel_initialized
else:
for i in range(nnodes):
for j in range(num_infer_workers):
group_ranks[i * num_infer_workers + j] = [train_device_count * i + j]
return __origin_init__(self, group_ranks, local_rank, *args, **kwargs)
GroupCoordinator.__init__ = __init__
try:
with profiling_patch, restore_torch_device_after_vllm_init():
torch.distributed.get_world_size_origin = torch.distributed.get_world_size
torch.distributed.get_world_size = get_world_size
yield
torch.distributed.get_world_size = torch.distributed.get_world_size_origin
del torch.distributed.get_world_size_origin
finally:
GroupCoordinator.__init__ = __origin_init__
return _get_context() if dist.is_initialized() else nullcontext()
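# Illustrative usage sketch (assumption): patch_vllm returns a context manager (a no-op
# when torch.distributed is not initialized); the vLLM engine should be created inside it
# so GroupCoordinator ranks are remapped onto the inference devices. `create_llm_fn` is a
# hypothetical callable, e.g. one that constructs vllm.LLM(...).
def _example_patch_vllm(create_llm_fn, num_infer_workers=1):
    with patch_vllm(world_size=num_infer_workers):
        return create_llm_fn()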
def patch_npu_vllm(vllm_device: str):
if isinstance(vllm_device, int):
vllm_device = get_device(vllm_device)
device_type = vllm_device.split(':')[0]
@contextmanager
def new_group_context():
original_new_group = torch.distributed.new_group
try:
torch.distributed.new_group = partial(original_new_group, use_local_synchronization=True)
torch.npu.mem_get_info = partial(torch.npu.mem_get_info, device=vllm_device)
yield
finally:
torch.distributed.new_group = original_new_group
return new_group_context() if device_type == 'npu' else nullcontext()
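# Illustrative usage sketch (assumption): on Ascend NPU the returned context manager pins
# torch.npu.mem_get_info to the chosen device and forces use_local_synchronization for
# new process groups while the engine is built; on other devices it is a no-op.
def _example_patch_npu_vllm(create_llm_fn, device='npu:0'):
    with patch_npu_vllm(device):
        return create_llm_fn()  # hypothetical callable building the vLLM engine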
@contextmanager
def set_device_context(device: Union[str, int]):
origin_device = get_current_device()
set_device(device)
try:
yield
finally:
set_device(origin_device)
@contextmanager
def restore_torch_device_after_vllm_init():
"""
A context manager to restore the original CUDA device after potential modifications.
This is specifically designed to address an issue in Distributed Data Parallel (DDP)
scenarios where the initialization of the vLLM engine may inadvertently modify the
default CUDA device. The context manager saves the current device at the start and
ensures it is restored upon exit, even if the device is modified within the context.
"""
origin_device = get_current_device()
try:
yield
finally:
current_device = get_current_device()
if origin_device != current_device:
set_device(origin_device)
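# Illustrative usage sketch (assumption): wrapping vLLM engine construction so that a
# device switch performed during initialization does not leak into the caller (relevant
# for DDP training loops). `create_engine_fn` is a hypothetical callable.
def _example_restore_device(create_engine_fn):
    with restore_torch_device_after_vllm_init():
        return create_engine_fn()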
def patch_vllm_memory_leak():
import vllm
if version.parse(vllm.__version__) != version.parse('0.7.3'):
return
def patch_vllm_abort_seq_group():
from vllm.core.scheduler import Scheduler
from typing import Iterable, Dict
from vllm.sequence import SequenceGroupBase, SequenceGroup, SequenceStatus
def new_abort_seq_group(
self,
request_id: Union[str, Iterable[str]],
seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
) -> None:
if isinstance(request_id, str):
request_id = (request_id, )
request_ids = set(request_id)
seq_id_to_seq_group = seq_id_to_seq_group or {}
for state_queue in [self.waiting, self.running, self.swapped]:
aborted_groups: List[SequenceGroup] = []
for seq_group in state_queue:
# When n>1, seq_group.request_id looks like
# foo_parallel_sample_0, while request_ids is just foo, and we
# should resolve it as real_request_id to match.
if seq_group.request_id in seq_id_to_seq_group:
real_request_id = seq_id_to_seq_group[seq_group.request_id].group_id
else:
real_request_id = seq_group.request_id
if real_request_id in request_ids:
# Appending aborted group into pending list.
aborted_groups.append(seq_group)
# We can't remove real_request_id in request_ids here,
# because there may be other seq groups sharing the same
# real_request_id
for aborted_group in aborted_groups:
# Remove the sequence group from the state queue.
state_queue.remove(aborted_group)
# Remove the aborted request from the Mamba cache.
self._finished_requests_ids.append(aborted_group.request_id)
for seq in aborted_group.get_seqs():
if seq.is_finished():
continue
seq.status = SequenceStatus.FINISHED_ABORTED
self.free_seq(seq)
if aborted_group.request_id in seq_id_to_seq_group:
del seq_id_to_seq_group[aborted_group.request_id]
self._free_seq_group_cross_attn_blocks(aborted_group)
origin_method = Scheduler.abort_seq_group
Scheduler._old_abort_seq_group = origin_method
Scheduler.abort_seq_group = new_abort_seq_group
def patch_vllm_engine():
from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
from vllm.outputs import PoolingRequestOutput, RequestOutput
from vllm.sequence import ExecuteModelRequest
def new_abort_request(self, request_id) -> None:
for scheduler in self.scheduler:
scheduler.abort_seq_group(request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)
origin_method = LLMEngine.abort_request
LLMEngine._old_abort_request = origin_method
LLMEngine.abort_request = new_abort_request
def new_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
if self.parallel_config.pipeline_parallel_size > 1:
raise NotImplementedError('Pipeline parallelism is only supported through AsyncLLMEngine '
'as performance will be severely degraded otherwise.')
# For llm_engine, there is no pipeline parallel support, so the engine
# used is always 0.
virtual_engine = 0
# These are cached outputs from previous iterations. None if on first
# iteration
cached_outputs = self.cached_scheduler_outputs[virtual_engine]
seq_group_metadata_list = cached_outputs.seq_group_metadata_list
scheduler_outputs = cached_outputs.scheduler_outputs
allow_async_output_proc = cached_outputs.allow_async_output_proc
ctx = self.scheduler_contexts[virtual_engine]
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
# The scheduler is also skipped if a single request caused the last
# engine step to fail, and the previous schedule needs to be rerun.
if not self._has_remaining_steps(seq_group_metadata_list):
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc) = self.scheduler[virtual_engine].schedule()
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
finished_requests_ids = self.scheduler[virtual_engine].get_and_reset_finished_requests_ids()
# When n>1, elements in self.seq_id_to_seq_group should be deleted
# here, otherwise memory leaks.
for finished_request_id in finished_requests_ids:
if finished_request_id in self.seq_id_to_seq_group:
del self.seq_id_to_seq_group[finished_request_id]
# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0):
# cache the scheduler outputs for the next iteration if we have
# lookahead slots
self._cache_scheduler_outputs_for_multi_step(virtual_engine, seq_group_metadata_list,
scheduler_outputs, allow_async_output_proc)
else:
finished_requests_ids = list()
assert seq_group_metadata_list is not None
assert scheduler_outputs is not None
if not scheduler_outputs.is_empty():
# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
# sampled_token_ids, as a separate broadcast over all the PP stages
# will cause one virtual engine's microbatch to block the pipeline.
last_sampled_token_ids = \
self._get_last_sampled_token_ids(virtual_engine)
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=seq_group_metadata_list,
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
finished_requests_ids=finished_requests_ids,
# We use ExecuteModelRequest to pass the last sampled_token_ids
# to each of the non-last PP stages for in-place prepare_input.
last_sampled_token_ids=last_sampled_token_ids)
if allow_async_output_proc:
execute_model_req.async_callback = self.async_callbacks[virtual_engine]
outputs = self.model_executor.execute_model(execute_model_req=execute_model_req)
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
self._update_cached_scheduler_output(virtual_engine, outputs)
else:
# Nothing scheduled => If there is pending async postprocessor,
# then finish it here.
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
outputs = []
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
for seq_group in seq_group_metadata_list:
seq_group.finish_step()
if not self._has_remaining_steps(seq_group_metadata_list):
# clear the cache if we have finished all the steps.
if self.scheduler_config.is_multi_step:
self.cached_scheduler_outputs[0] = SchedulerOutputState()
# is_first_step_output is True only when the num_steps of all
# the sequences are 1. When the num_steps > 1,
# multi_step_model_runner does the first-step output append.
is_first_step_output: bool = False if not seq_group_metadata_list \
else seq_group_metadata_list[0].state.num_steps == 1
# Add results to the output_queue
ctx.append_output(
outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
scheduler_outputs=scheduler_outputs,
is_async=allow_async_output_proc,
is_last_step=True,
is_first_step_output=is_first_step_output)
if outputs and allow_async_output_proc:
assert len(outputs) == 1, ('Async postprocessor expects only a single output set')
self._advance_to_next_step(outputs[0], seq_group_metadata_list,
scheduler_outputs.scheduled_seq_groups)
# Check if need to run the usual non-async path
if not allow_async_output_proc:
self._process_model_outputs(ctx=ctx)
# Log stats.
self.do_log_stats(scheduler_outputs, outputs)
# Tracing
self.do_tracing(scheduler_outputs)
else:
# Multi-step case
return ctx.request_outputs
if not self.has_unfinished_requests():
# Drain async postprocessor (if exists)
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
assert len(ctx.output_queue) == 0
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
self.model_executor.stop_remote_worker_execution_loop()
return ctx.request_outputs
origin_method = LLMEngine.step
LLMEngine._old_step = origin_method
LLMEngine.step = new_step
patch_vllm_abort_seq_group()
patch_vllm_engine()
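# Illustrative usage sketch (assumption): the memory-leak patches only target
# vllm==0.7.3 (patch_vllm_memory_leak returns early for other versions) and should be
# applied before the engine is built. `create_llm_fn` is a hypothetical callable.
def _example_patch_vllm_memory_leak(create_llm_fn):
    patch_vllm_memory_leak()
    return create_llm_fn()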