"""Smoke tests for the swift.llm inference engines (pt / vllm / lmdeploy)."""
import os
from typing import Literal

import torch

# Restrict the tests to a single GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'


def _prepare(infer_backend: Literal['vllm', 'pt', 'lmdeploy']):
    from swift.llm import InferRequest, get_template
    if infer_backend == 'lmdeploy':
        from swift.llm import LmdeployEngine
        # lmdeploy's TurboMind backend runs in half precision; float32 is not supported.
        engine = LmdeployEngine('OpenGVLab/InternVL2_5-2B', torch.float16)
    elif infer_backend == 'pt':
        from swift.llm import PtEngine
        engine = PtEngine('Qwen/Qwen2-7B-Instruct', max_batch_size=16)
    elif infer_backend == 'vllm':
        from swift.llm import VllmEngine
        engine = VllmEngine('Qwen/Qwen2-7B-Instruct')
    else:
        raise ValueError(f'Unsupported infer_backend: {infer_backend}')
    template = get_template(engine.model_meta.template, engine.processor)
    infer_requests = [
        InferRequest([{
            'role': 'user',
            'content': 'hello! who are you'
        }]) for _ in range(100)
    ]
    return engine, template, infer_requests


def test_infer(infer_backend: Literal['vllm', 'pt', 'lmdeploy']):
    from swift.llm import RequestConfig
    from swift.plugin import InferStats
    engine, template, infer_requests = _prepare(infer_backend=infer_backend)
    request_config = RequestConfig(temperature=0)  # greedy decoding for reproducible output
    infer_stats = InferStats()

    response_list = engine.infer(
        infer_requests, template=template, request_config=request_config, metrics=[infer_stats])

    # Print the first two completions, then the aggregated runtime metrics.
    for response in response_list[:2]:
        print(response.choices[0].message.content)
    print(infer_stats.compute())


def test_stream(infer_backend: Literal['vllm', 'pt', 'lmdeploy']):
    from swift.llm import RequestConfig
    from swift.plugin import InferStats
    engine, template, infer_requests = _prepare(infer_backend=infer_backend)
    infer_stats = InferStats()
    request_config = RequestConfig(temperature=0, stream=True, logprobs=True)

    # With stream=True, infer returns one generator per request; stream the first reply.
    gen_list = engine.infer(infer_requests, template=template, request_config=request_config, metrics=[infer_stats])

    for response in gen_list[0]:
        if response is None:
            continue
        print(response.choices[0].delta.content, end='', flush=True)
    print()
    print(infer_stats.compute())

    # Run again with a progress bar, draining the first generator without printing.
    gen_list = engine.infer(
        infer_requests, template=template, request_config=request_config, use_tqdm=True, metrics=[infer_stats])

    for response in gen_list[0]:
        pass

    print(infer_stats.compute())


if __name__ == '__main__':
    test_infer('pt')
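    # Hypothetical extra runs (assumes the vllm / lmdeploy backends are installed):
    # test_infer('vllm')
    # test_stream('pt')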