# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import re
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from functools import partial
from itertools import repeat
from queue import Queue
from typing import List, Optional, Union

import torch
import torch.distributed as dist
from packaging import version
from transformers import GenerationConfig, LogitsProcessor
from transformers.generation.streamers import BaseStreamer

from swift.llm.model.register import fix_do_sample_warning
from swift.utils import get_current_device, get_device, get_device_count, get_node_setting, set_device
from ..protocol import RequestConfig


@dataclass
class AdapterRequest:
    name: str
    path: str


class InferTools:

    @staticmethod
    def _is_chinese_char(cp: int) -> bool:
        """Checks whether CP is the codepoint of a CJK character."""
        # copy from transformers.generation.streamers.TextStreamer
        if ((0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF) or (0x20000 <= cp <= 0x2A6DF)
                or (0x2A700 <= cp <= 0x2B73F) or (0x2B740 <= cp <= 0x2B81F) or (0x2B820 <= cp <= 0x2CEAF)
                or (0xF900 <= cp <= 0xFAFF) or (0x2F800 <= cp <= 0x2FA1F)):
            return True

        return False


class InferStreamer(InferTools):

    def __init__(self, template, **decode_kwargs):
        self.template = template
        self.tokenizer = template.tokenizer

        self.cache_idx = 0  # token idx
        self.print_idx = 0
        self.decode_kwargs = decode_kwargs
        self.first_num_space = -1  # The number of whitespace characters before the first token.
        self.first_token = True

    def _align_blank_suffix(self, response: str) -> str:
        # Avoid repeated words appearing in the sentence.
        cur_num_space = len(response) - len(response.lstrip(' '))
        if self.first_num_space == -1:
            self.first_num_space = cur_num_space
        elif cur_num_space < self.first_num_space:
            response = ' ' * (self.first_num_space - cur_num_space) + response
        elif cur_num_space > self.first_num_space:
            response = response[cur_num_space - self.first_num_space:]
        return response

    def _get_response(self, response: str, is_finished: bool, token_len: int) -> str:
        # After the symbol for a new line, we flush the cache.
        if self.first_token:
            printable_text = response
            self.first_token = False
        elif response.endswith('\n') or is_finished:
            printable_text = response[self.print_idx:]
            self.cache_idx += token_len
            self.first_num_space = -1
            self.print_idx = 0
        # If the last token is a CJK character, we print the characters.
        elif len(response) > 0 and self._is_chinese_char(ord(response[-1])):
            printable_text = response[self.print_idx:]
            self.print_idx += len(printable_text)
        # Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
        else:
            printable_text = response[self.print_idx:response.rfind(' ') + 1]
            self.print_idx += len(printable_text)
        return printable_text

    def get_printable_text(self, raw_tokens: List[int], is_finished: bool) -> str:
        raw_tokens = raw_tokens[self.cache_idx:]
        if self.first_token:
            raw_tokens = []
        response = self.template.decode(
            raw_tokens, is_finished=is_finished, tokenizer_kwargs=self.decode_kwargs, first_token=self.first_token)
        response = self._align_blank_suffix(response)
        return self._get_response(response, is_finished, len(raw_tokens))
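
# A minimal usage sketch for InferStreamer (illustrative only; `template` and the incremental
# token source are assumed to exist and are not part of this module):
#
#     streamer = InferStreamer(template)
#     generated: List[int] = []
#     for token_id in token_iterator:  # token ids yielded one by one during generation
#         generated.append(token_id)
#         delta = streamer.get_printable_text(generated, is_finished=False)
#         if delta:
#             print(delta, end='', flush=True)
#     print(streamer.get_printable_text(generated, is_finished=True))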


class StreamerMixin:

    def __init__(self):
        self.queue = Queue()

    def __iter__(self):
        return self

    def __next__(self) -> torch.Tensor:
        value = self.queue.get()
        if value is None:
            raise StopIteration()
        else:
            return value


class TokensIteratorStreamer(StreamerMixin, BaseStreamer):

    def put(self, value: torch.Tensor) -> None:
        self.queue.put(value)

    def end(self) -> None:
        self.queue.put(None)
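
# A minimal usage sketch for TokensIteratorStreamer (illustrative only; `model` and `inputs` are
# assumed to be a transformers model and its tokenized inputs):
#
#     from threading import Thread
#     streamer = TokensIteratorStreamer()
#     thread = Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer))
#     thread.start()
#     for tokens in streamer:  # blocks on the queue; iteration stops once end() is called
#         ...
#     thread.join()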


class LogitsStreamer(LogitsProcessor):

    def __init__(self):
        self.queue = Queue()

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        self.queue.put(scores)
        return scores
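
# A minimal usage sketch for LogitsStreamer (illustrative only; `model` and `inputs` are assumed
# to exist, and LogitsProcessorList comes from transformers):
#
#     logits_streamer = LogitsStreamer()
#     model.generate(**inputs, logits_processor=LogitsProcessorList([logits_streamer]))
#     # per-step scores can then be consumed, e.g. from another thread:
#     # scores = logits_streamer.queue.get()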


def _set_generation_config_default_value(model_generation_config: GenerationConfig,
                                         generation_config: GenerationConfig) -> GenerationConfig:
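    """Fill unset fields of generation_config with defaults from the model's generation config.

    max_length is skipped, no_repeat_ngram_size is always taken from the model config, and any
    other field is copied only when the model defines it and the request-level config left it unset.
    """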
    for k, v in model_generation_config.to_dict().items():
        new_v = getattr(generation_config, k, None)
        if k in ['max_length']:
            continue
        if k in ['no_repeat_ngram_size'] or (v is not None and new_v is None):
            setattr(generation_config, k, v)
    return generation_config


def prepare_generation_config(model_generation_config: Optional[GenerationConfig], request_config: RequestConfig,
                              tokenizer) -> Optional[GenerationConfig]:
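    """Build a GenerationConfig for this request.

    Sampling parameters left unset in request_config fall back to model_generation_config;
    a temperature of 0 (or an unset temperature with a greedy model default) disables sampling,
    and the eos/pad token ids are filled in from the tokenizer.
    """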
    if model_generation_config is None or request_config is None:
        return model_generation_config
    kwargs = {'max_new_tokens': request_config.max_tokens}
    # Not used: 'n', 'best_of', 'frequency_penalty', 'presence_penalty'.
    for key in ['length_penalty']:
        kwargs[key] = getattr(request_config, key)
    for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty', 'num_beams']:
        new_value = getattr(request_config, key)
        if new_value is None:
            kwargs[key] = getattr(model_generation_config, key)
        else:
            kwargs[key] = new_value

    if not model_generation_config.do_sample and request_config.temperature in {0, None}:
        kwargs['temperature'] = 0
    if kwargs['temperature'] == 0:
        kwargs['do_sample'] = False
        kwargs['temperature'] = 1
        kwargs['top_p'] = 1
        kwargs['top_k'] = 50
    else:
        kwargs['do_sample'] = True
    generation_config = GenerationConfig(**kwargs)
    generation_config = _set_generation_config_default_value(model_generation_config, generation_config)
    fix_do_sample_warning(generation_config)

    if generation_config.eos_token_id is None:
        generation_config.eos_token_id = tokenizer.eos_token_id
    generation_config.pad_token_id = tokenizer.pad_token_id
    return generation_config
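
# A minimal usage sketch for prepare_generation_config (illustrative only; it assumes RequestConfig
# accepts these fields as keyword arguments and that `model`, `tokenizer` and `inputs` exist):
#
#     request_config = RequestConfig(max_tokens=512, temperature=0.7, top_p=0.9)
#     generation_config = prepare_generation_config(model.generation_config, request_config, tokenizer)
#     model.generate(**inputs, generation_config=generation_config)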


def patch_lmdeploy(load_weights=False):
    """This patch allows lmdeploy to select devices and to reload the state_dict."""
    import lmdeploy
    assert version.parse(lmdeploy.__version__) >= version.parse('0.7.0')
    from lmdeploy.messages import TurbomindEngineConfig
    from lmdeploy.turbomind.deploy import loader
    from lmdeploy.turbomind.deploy.loader import create_loader
    from lmdeploy.turbomind.deploy.source_model import llama

    def _create_loader(model_path: str, pattern: str):
        if not isinstance(model_path, (str, os.PathLike)):

            def generate():
                generator = OrderedDict()
                model_dict = {}
                if not isinstance(model_path, dict):
                    for key, value in list(model_path):
                        model_dict[key] = value
                else:
                    model_dict = model_path
                for key, value in model_dict.items():
                    match = re.findall(pattern, key)
                    if not match:
                        if -1 not in generator:
                            generator[-1] = {}
                        generator[-1][key] = value
                    else:
                        layer = int(match[0])
                        if layer not in generator:
                            generator[layer] = {}
                        generator[layer][key] = value
                return generator

            return generate()
        else:
            return create_loader(model_path, pattern)

    loader.create_loader = _create_loader
    llama.create_loader = _create_loader

    TurbomindEngineConfig.devices = [0]

    from lmdeploy.turbomind.turbomind import TurboMind
    from lmdeploy.turbomind.utils import ModelSource

    @contextmanager
    def patch_threadpool_map():
        ThreadPoolExecutor.map_origin = ThreadPoolExecutor.map
        ThreadPoolExecutor.map = lambda *args, **kwargs: []
        yield
        ThreadPoolExecutor.map = ThreadPoolExecutor.map_origin
        del ThreadPoolExecutor.map_origin

    @contextmanager
    def tm_model_context(self):

        def _get_tm_model(model_path,
                          model_name,
                          chat_template_name,
                          engine_config: TurbomindEngineConfig,
                          group_size: int = None,
                          out_dir: str = None):
            from lmdeploy.turbomind.deploy.converter import get_tm_model_origin
            tm_model = get_tm_model_origin(model_path, model_name, chat_template_name, engine_config, group_size,
                                           out_dir)
            self.tm_model = tm_model
            return tm_model

        from lmdeploy.turbomind.deploy import converter
        converter.get_tm_model_origin = converter.get_tm_model
        converter.get_tm_model = _get_tm_model
        yield
        converter.get_tm_model = converter.get_tm_model_origin
        del converter.get_tm_model_origin
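
    # Replacement for TurboMind.__init__: record the requested devices, run the original
    # initializer with ThreadPoolExecutor.map disabled (so its per-device work is skipped),
    # then process the weights (unless load_weights=True) and create the engine on each
    # device in gpu_list.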

    def __init__(self,
                 model_path: str,
                 tokenizer: object,
                 model_name: str = None,
                 chat_template_name: str = None,
                 engine_config: TurbomindEngineConfig = None,
                 model_source: ModelSource = ModelSource.WORKSPACE,
                 **kwargs):
        self.gpu_list = engine_config.devices
        with patch_threadpool_map(), tm_model_context(self):
            self.__origin_init__(model_path, tokenizer, model_name, chat_template_name, engine_config, model_source,
                                 **kwargs)

        with ThreadPoolExecutor(max_workers=self.gpu_count) as e:
            ranks = [self.node_id * self.gpu_count + device_id for device_id in range(self.gpu_count)]
            if not load_weights:
                for _ in e.map(self.model_comm.process_weight, self.gpu_list, ranks):
                    pass
            if version.parse(lmdeploy.__version__) < version.parse('0.7.2'):
                for _ in e.map(self.model_comm.create_engine, self.gpu_list, ranks, repeat(self.nccl_params)):
                    pass
            else:
                for _ in e.map(self.model_comm.create_engine, self.gpu_list, ranks):
                    pass

    def _create_weight(self, model_comm):
        """Allocate weight buffer, load params if from_workspace."""

        # TODO: support mpi
        self.node_id = 0
        self.node_num = 1
        if version.parse(lmdeploy.__version__) < version.parse('0.7.2'):
            self.nccl_params = model_comm.create_nccl_params(self.node_id)
        torch.cuda.synchronize()

        # create weight
        def _create_weight_func(index, device_id):
            rank = self.node_id * self.gpu_count + index
            model_comm.create_shared_weights(device_id, rank)

        with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
            futures = []
            for idx, device_id in enumerate(self.gpu_list):
                futures.append(executor.submit(_create_weight_func, idx, device_id))
            for future in futures:
                future.result()

    def _get_model_params(self, model_comm, tm_params):
        """Get turbomind model params when loading from hf."""

        def _get_params(idx, device_id, que):
            rank = self.node_id * self.gpu_count + idx
            out = model_comm.get_params(device_id, rank)
            que.put(out)

        que = Queue()
        with ThreadPoolExecutor(max_workers=self.gpu_count) as executor:
            futures = []
            for idx, device_id in enumerate(self.gpu_list):
                futures.append(executor.submit(_get_params, idx, device_id, que))
            for future in futures:
                future.result()

        for _ in range(self.gpu_count):
            tensor_map = que.get()
            for k, v in tensor_map.items():
                if k not in tm_params:
                    tm_params[k] = []
                tm_params[k].append(v)

    def _load_weights(self, state_dict):
        tm_params = self.tm_model.tm_params
        self._get_model_params(self.model_comm, tm_params)
        input_model = self.tm_model.input_model
        model_path = input_model.model_path
        input_model.model_path = state_dict
        self.tm_model.export()
        input_model.model_path = model_path

    from lmdeploy.turbomind.turbomind import TurboMindInstance

    def create_instance(self, cuda_stream_id=0):
        return TurboMindInstance(self, self.config, cuda_stream_id, self.gpu_list)

    TurboMind.__origin_init__ = TurboMind.__init__
    TurboMind.__init__ = __init__
    TurboMind._create_weight = _create_weight
    TurboMind._get_model_params = _get_model_params
    TurboMind.create_instance = create_instance
    if load_weights:
        TurboMind.load_weights = _load_weights
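
    # Replacements for TurboMindInstance: remember the selected devices and create the model
    # instance on the first device in gpu_list (the device_id argument is ignored).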

    def __init_ins__(self, tm_model, config, cuda_stream_id=0, gpu_list=None):
        if gpu_list is None:
            gpu_list = [0]
        self.gpu_list = gpu_list
        self.__origin_init__(tm_model, config, cuda_stream_id)

    def _create_model_instance(self, device_id):
        model_inst = self.tm_model.model_comm.create_model_instance(self.gpu_list[0])
        return model_inst

    TurboMindInstance.__origin_init__ = TurboMindInstance.__init__
    TurboMindInstance.__init__ = __init_ins__
    TurboMindInstance._create_model_instance = _create_model_instance
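
# A minimal usage sketch for patch_lmdeploy (illustrative only; the engine is built through the
# normal lmdeploy turbomind flow, which is not prescribed here):
#
#     patch_lmdeploy(load_weights=True)
#     engine_config = TurbomindEngineConfig(...)
#     engine_config.devices = [0]  # devices consumed by the patched TurboMind.__init__
#     # build the TurboMind engine as usual; with load_weights=True the patched class also
#     # exposes engine.load_weights(state_dict) for reloading weights in place.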


def patch_vllm(world_size=1):
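    """Patch vLLM so it can run colocated with training processes (e.g. DDP / multi-node).

    Inside the returned context, GroupCoordinator.__init__ remaps local ranks onto the trailing
    GPUs of each node, torch.distributed.get_world_size() (called without a group) reports the
    given world_size, and the worker memory-profiling assertion is disabled when present.
    Returns a nullcontext when torch.distributed is not initialized.
    """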

    @contextmanager
    def _get_context():
        from vllm.distributed.parallel_state import GroupCoordinator
        from unittest.mock import patch
        try:
            from vllm.worker.worker import Worker
            getattr(Worker, '_assert_memory_footprint_increased_during_profiling')
            profiling_patch = patch(
                'vllm.worker.worker.Worker._assert_memory_footprint_increased_during_profiling', return_value=None)
        except (ImportError, AttributeError):
            profiling_patch = nullcontext()

        __origin_init__ = GroupCoordinator.__init__

        def get_world_size(group=None) -> int:
            if not group:
                # No group given: return the world_size passed to patch_vllm
                return world_size
            else:
                return torch.distributed.get_world_size_origin(group)

        def __init__(self, group_ranks, local_rank, *args, **kwargs):
            node_rank, nnodes = get_node_setting()
            device_count = get_device_count()
            num_infer_workers = world_size // nnodes

            def map_rank_to_real_device(obj):
                # Use the last devices:
                # e.g. world_size=4, gpus=8 -> ranks [0, 1, 2, 3] are mapped to devices [4, 5, 6, 7]
                diff = device_count - num_infer_workers
                if diff < 0:
                    diff = 0
                if isinstance(obj, list):
                    return [map_rank_to_real_device(o) for o in obj]
                elif isinstance(obj, int):
                    return obj + diff
                else:
                    raise ValueError(f'Unsupported type: {obj}')

            if kwargs.get('group_name') == 'world':
                local_rank = local_rank + node_rank * num_infer_workers
            else:
                local_rank = map_rank_to_real_device(local_rank - node_rank * num_infer_workers)
            rank = dist.get_rank()
            if world_size == 1 and [rank] not in group_ranks:
                # for ddp inference
                group_ranks = [[rank]]
            if nnodes > 1 and num_infer_workers < device_count:
                """
                Map group_ranks to global ranks

                Example:
                  - Number of nodes (nnodes): 2
                  - Devices per node (device_count): 4
                  - Inference workers per node (num_infer_workers): 1

                  Initial group_ranks:
                      [[0, 1]]

                  After mapping to global ranks:
                      [[0, 3]]  # Global ranks corresponding to the local ranks
                """
                train_device_count = device_count - num_infer_workers
                # vllm.worker.init_distributed_environment
                if len(group_ranks) == 1:
                    group_ranks = group_ranks[0]
                    for i in range(nnodes):
                        group_ranks[i * num_infer_workers:(i + 1) * num_infer_workers] = [
                            train_device_count * i + j for j in range(num_infer_workers)
                        ]
                    group_ranks = [group_ranks]
                # vllm.worker.ensure_model_parallel_initialized
                else:
                    for i in range(nnodes):
                        for j in range(num_infer_workers):
                            group_ranks[i * num_infer_workers + j] = [train_device_count * i + j]

            return __origin_init__(self, group_ranks, local_rank, *args, **kwargs)

        GroupCoordinator.__init__ = __init__

        try:
            with profiling_patch, restore_torch_device_after_vllm_init():
                torch.distributed.get_world_size_origin = torch.distributed.get_world_size
                torch.distributed.get_world_size = get_world_size
                yield
                torch.distributed.get_world_size = torch.distributed.get_world_size_origin
                del torch.distributed.get_world_size_origin
        finally:
            GroupCoordinator.__init__ = __origin_init__

    return _get_context() if dist.is_initialized() else nullcontext()


def patch_npu_vllm(vllm_device: str):
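    """On NPU devices, return a context manager that creates new process groups with
    use_local_synchronization=True and binds torch.npu.mem_get_info to vllm_device;
    on other devices this is a no-op nullcontext."""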
    if isinstance(vllm_device, int):
        vllm_device = get_device(vllm_device)
    device_type = vllm_device.split(':')[0]

    @contextmanager
    def new_group_context():
        original_new_group = torch.distributed.new_group
        try:
            torch.distributed.new_group = partial(original_new_group, use_local_synchronization=True)
            torch.npu.mem_get_info = partial(torch.npu.mem_get_info, device=vllm_device)
            yield
        finally:
            torch.distributed.new_group = original_new_group

    return new_group_context() if device_type == 'npu' else nullcontext()


@contextmanager
def set_device_context(device: Union[str, int]):
    origin_device = get_current_device()
    set_device(device)
    try:
        yield
    finally:
        set_device(origin_device)


@contextmanager
def restore_torch_device_after_vllm_init():
    """
    A context manager to restore the original CUDA device after potential modifications.

    This is specifically designed to address an issue in Distributed Data Parallel (DDP)
    scenarios where the initialization of the vLLM engine may inadvertently modify the
    default CUDA device. The context manager saves the current device at the start and
    ensures it is restored upon exit, even if the device is modified within the context.

    """
    origin_device = get_current_device()
    try:
        yield
    finally:
        current_device = get_current_device()
        if origin_device != current_device:
            set_device(origin_device)


def patch_vllm_memory_leak():
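    """Work around a memory leak in vllm==0.7.3: with n > 1, entries of
    LLMEngine.seq_id_to_seq_group are not released for finished or aborted requests.
    Patches Scheduler.abort_seq_group, LLMEngine.abort_request and LLMEngine.step;
    does nothing for other vLLM versions."""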
    import vllm
    if version.parse(vllm.__version__) != version.parse('0.7.3'):
        return

    def patch_vllm_abort_seq_group():
        from vllm.core.scheduler import Scheduler
        from typing import Iterable, Dict
        from vllm.sequence import SequenceGroupBase, SequenceGroup, SequenceStatus

        def new_abort_seq_group(
            self,
            request_id: Union[str, Iterable[str]],
            seq_id_to_seq_group: Optional[Dict[str, SequenceGroupBase]] = None,
        ) -> None:
            if isinstance(request_id, str):
                request_id = (request_id, )
            request_ids = set(request_id)
            seq_id_to_seq_group = seq_id_to_seq_group or {}
            for state_queue in [self.waiting, self.running, self.swapped]:
                aborted_groups: List[SequenceGroup] = []
                for seq_group in state_queue:
                    # When n>1, seq_group.request_id looks like
                    # foo_parallel_sample_0, while request_ids is just foo, and we
                    # should resolve it as real_request_id to match.
                    if seq_group.request_id in seq_id_to_seq_group:
                        real_request_id = seq_id_to_seq_group[seq_group.request_id].group_id
                    else:
                        real_request_id = seq_group.request_id
                    if real_request_id in request_ids:
                        # Appending aborted group into pending list.
                        aborted_groups.append(seq_group)
                        # We can't remove real_request_id in request_ids here,
                        # because there may be other seq groups sharing the same
                        # real_request_id
                for aborted_group in aborted_groups:
                    # Remove the sequence group from the state queue.
                    state_queue.remove(aborted_group)
                    # Remove the aborted request from the Mamba cache.
                    self._finished_requests_ids.append(aborted_group.request_id)
                    for seq in aborted_group.get_seqs():
                        if seq.is_finished():
                            continue
                        seq.status = SequenceStatus.FINISHED_ABORTED
                        self.free_seq(seq)
                    if aborted_group.request_id in seq_id_to_seq_group:
                        del seq_id_to_seq_group[aborted_group.request_id]

                    self._free_seq_group_cross_attn_blocks(aborted_group)

        origin_method = Scheduler.abort_seq_group
        Scheduler._old_abort_seq_group = origin_method
        Scheduler.abort_seq_group = new_abort_seq_group

    def patch_vllm_engine():
        from vllm.engine.llm_engine import LLMEngine, SchedulerOutputState
        from vllm.outputs import PoolingRequestOutput, RequestOutput
        from vllm.sequence import ExecuteModelRequest

        def new_abort_request(self, request_id) -> None:
            for scheduler in self.scheduler:
                scheduler.abort_seq_group(request_id, seq_id_to_seq_group=self.seq_id_to_seq_group)

        origin_method = LLMEngine.abort_request
        LLMEngine._old_abort_request = origin_method
        LLMEngine.abort_request = new_abort_request

        def new_step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
            if self.parallel_config.pipeline_parallel_size > 1:
                raise NotImplementedError('Pipeline parallelism is only supported through AsyncLLMEngine '
                                          'as performance will be severely degraded otherwise.')

            # For llm_engine, there is no pipeline parallel support, so the engine
            # used is always 0.
            virtual_engine = 0

            # These are cached outputs from previous iterations. None if on first
            # iteration
            cached_outputs = self.cached_scheduler_outputs[virtual_engine]
            seq_group_metadata_list = cached_outputs.seq_group_metadata_list
            scheduler_outputs = cached_outputs.scheduler_outputs
            allow_async_output_proc = cached_outputs.allow_async_output_proc

            ctx = self.scheduler_contexts[virtual_engine]

            # Clear outputs for each new scheduler iteration
            ctx.request_outputs.clear()

            # Skip the scheduler if there are any remaining steps in the seq groups.
            # This ensures that the scheduler is only called again when the current
            # batch has completed.
            # The scheduler is also skipped if a single request caused the last
            # engine step to fail, and the previous schedule needs to be rerun.
            if not self._has_remaining_steps(seq_group_metadata_list):
                # Schedule iteration
                (seq_group_metadata_list, scheduler_outputs,
                 allow_async_output_proc) = self.scheduler[virtual_engine].schedule()

                ctx.seq_group_metadata_list = seq_group_metadata_list
                ctx.scheduler_outputs = scheduler_outputs

                finished_requests_ids = self.scheduler[virtual_engine].get_and_reset_finished_requests_ids()
                # When n>1, elements in self.seq_id_to_seq_group should be deleted
                # here, otherwise memory leaks.
                for finished_request_id in finished_requests_ids:
                    if finished_request_id in self.seq_id_to_seq_group:
                        del self.seq_id_to_seq_group[finished_request_id]

                # Maybe switch from async mode to sync mode
                if not allow_async_output_proc and len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)

                if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0):
                    # cache the scheduler outputs for the next iteration if we have
                    # lookahead slots
                    self._cache_scheduler_outputs_for_multi_step(virtual_engine, seq_group_metadata_list,
                                                                 scheduler_outputs, allow_async_output_proc)
            else:
                finished_requests_ids = list()

            assert seq_group_metadata_list is not None
            assert scheduler_outputs is not None

            if not scheduler_outputs.is_empty():

                # Check if we have a cached last_output from the previous iteration.
                # For supporting PP this is probably the best way to pass the
                # sampled_token_ids, as a separate broadcast over all the PP stages
                # will cause one virtual engine's microbatch to block the pipeline.
                last_sampled_token_ids = \
                    self._get_last_sampled_token_ids(virtual_engine)

                execute_model_req = ExecuteModelRequest(
                    seq_group_metadata_list=seq_group_metadata_list,
                    blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
                    blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
                    blocks_to_copy=scheduler_outputs.blocks_to_copy,
                    num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
                    running_queue_size=scheduler_outputs.running_queue_size,
                    finished_requests_ids=finished_requests_ids,
                    # We use ExecuteModelRequest to pass the last sampled_token_ids
                    # to each of the non-last PP stages for in-place prepare_input.
                    last_sampled_token_ids=last_sampled_token_ids)

                if allow_async_output_proc:
                    execute_model_req.async_callback = self.async_callbacks[virtual_engine]

                outputs = self.model_executor.execute_model(execute_model_req=execute_model_req)

                # We need to do this here so that last step's sampled_token_ids can
                # be passed to the next iteration for PP.
                if self.scheduler_config.is_multi_step:
                    self._update_cached_scheduler_output(virtual_engine, outputs)
            else:
                # Nothing scheduled => If there is pending async postprocessor,
                # then finish it here.
                if len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)
                # No outputs in this case
                outputs = []

            # Finish the current step for all the sequence groups.
            if self.scheduler_config.is_multi_step:
                for seq_group in seq_group_metadata_list:
                    seq_group.finish_step()

            if not self._has_remaining_steps(seq_group_metadata_list):
                # clear the cache if we have finished all the steps.
                if self.scheduler_config.is_multi_step:
                    self.cached_scheduler_outputs[0] = SchedulerOutputState()

                # is_first_step_output is True only when the num_steps of all
                # the sequences are 1. When the num_steps > 1,
                # multi_step_model_runner does the first-step output append.
                is_first_step_output: bool = False if not seq_group_metadata_list \
                    else seq_group_metadata_list[0].state.num_steps == 1

                # Add results to the output_queue
                ctx.append_output(
                    outputs=outputs,
                    seq_group_metadata_list=seq_group_metadata_list,
                    scheduler_outputs=scheduler_outputs,
                    is_async=allow_async_output_proc,
                    is_last_step=True,
                    is_first_step_output=is_first_step_output)

                if outputs and allow_async_output_proc:
                    assert len(outputs) == 1, ('Async postprocessor expects only a single output set')

                    self._advance_to_next_step(outputs[0], seq_group_metadata_list,
                                               scheduler_outputs.scheduled_seq_groups)

                # Check if need to run the usual non-async path
                if not allow_async_output_proc:
                    self._process_model_outputs(ctx=ctx)

                    # Log stats.
                    self.do_log_stats(scheduler_outputs, outputs)

                    # Tracing
                    self.do_tracing(scheduler_outputs)
            else:
                # Multi-step case
                return ctx.request_outputs

            if not self.has_unfinished_requests():
                # Drain async postprocessor (if exists)
                if len(ctx.output_queue) > 0:
                    self._process_model_outputs(ctx=ctx)
                assert len(ctx.output_queue) == 0

                # Stop the execute model loop in parallel workers until there are
                # more requests to process. This avoids waiting indefinitely in
                # torch.distributed ops which may otherwise timeout, and unblocks
                # the RPC thread in the workers so that they can process any other
                # queued control plane messages, such as add/remove lora adapters.
                self.model_executor.stop_remote_worker_execution_loop()

            return ctx.request_outputs

        origin_method = LLMEngine.step
        LLMEngine._old_step = origin_method
        LLMEngine.step = new_step

    patch_vllm_abort_seq_group()
    patch_vllm_engine()