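"""Smoke tests for the swift.llm inference engines.

Builds an engine (PtEngine, VllmEngine, or LmdeployEngine), a matching chat
template, and two sample requests (plain text and an image URL), then runs
batch and streaming inference over them.
"""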
import os
from typing import Literal

import torch

# Pin the test process to a single GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
|
def _prepare(infer_backend: Literal['vllm', 'pt', 'lmdeploy']):
    from swift.llm import InferRequest, get_template
    if infer_backend == 'lmdeploy':
        from swift.llm import LmdeployEngine
        # LMDeploy's TurboMind backend runs in half precision, so pass float16.
        engine = LmdeployEngine('Qwen/Qwen-VL-Chat', torch.float16)
    elif infer_backend == 'pt':
        from swift.llm import PtEngine
        engine = PtEngine('Qwen/Qwen2-VL-7B-Instruct')
    elif infer_backend == 'vllm':
        from swift.llm import VllmEngine
        engine = VllmEngine('Qwen/Qwen2-VL-7B-Instruct')
    else:
        raise ValueError(f'Unknown infer_backend: {infer_backend}')
    # Build the chat template matching the loaded model.
    template = get_template(engine.model_meta.template, engine.processor)
|
    infer_requests = [
        # Plain-text request; the prompt is Chinese for
        # 'What should I do if I can't sleep at night?'.
        InferRequest([{
            'role': 'user',
            'content': '晚上睡不着觉怎么办'
        }]),
        # Multimodal request: the content is a list with a single image URL part.
        InferRequest([{
            'role': 'user',
            'content': [{
                'type': 'image_url',
                'image_url': 'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/cat.png'
            }]
        }])
    ]
    return engine, template, infer_requests
|
|
def test_infer(engine, template, infer_requests):
    from swift.llm import RequestConfig
    from swift.plugin import InferStats
    # temperature=0 gives deterministic (greedy) decoding.
    request_config = RequestConfig(temperature=0)
    infer_stats = InferStats()

    response_list = engine.infer(
        infer_requests, template=template, request_config=request_config, metrics=[infer_stats])

    for response in response_list[:2]:
        print(response.choices[0].message.content)
    print(infer_stats.compute())
|
|
def test_stream(engine, template, infer_requests):
    from swift.llm import RequestConfig
    from swift.plugin import InferStats
    infer_stats = InferStats()
    # With stream=True, engine.infer returns one generator per request;
    # logprobs=True additionally requests per-token logprobs.
    request_config = RequestConfig(temperature=0, stream=True, logprobs=True)

    gen_list = engine.infer(infer_requests, template=template, request_config=request_config, metrics=[infer_stats])

    # Consume the stream of the first request, printing incremental deltas.
    for response in gen_list[0]:
        if response is None:  # the stream may yield empty packets; skip them
            continue
        print(response.choices[0].delta.content, end='', flush=True)
    print()
    print(infer_stats.compute())

    # Run again with a progress bar; drain the first stream without printing.
    gen_list = engine.infer(
        infer_requests, template=template, request_config=request_config, use_tqdm=True, metrics=[infer_stats])

    for response in gen_list[0]:
        pass

    print(infer_stats.compute())
|
|
if __name__ == '__main__':
    engine, template, infer_requests = _prepare(infer_backend='pt')
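    # The same pipeline works with the other backends (assuming vllm / lmdeploy
    # are installed):
    #   engine, template, infer_requests = _prepare(infer_backend='vllm')
    #   engine, template, infer_requests = _prepare(infer_backend='lmdeploy')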
|
    test_infer(engine, template, infer_requests)
    test_stream(engine, template, infer_requests)