|
|
|
|
|
import os |
|
from typing import List |
|
|
|
# Expose only CUDA device 0 to this process. Assigned at module level, so it is
# already set when the swift imports inside the `__main__` block execute.
os.environ['CUDA_VISIBLE_DEVICES'] = '0' |
|
|
|
|
|
def infer_batch(engine: 'InferEngine', infer_requests: List['InferRequest']):
    """Run one batched inference call and print the leading query/response pairs.

    All requests are passed to ``engine.infer`` in a single call. Only the
    first two results are echoed to stdout, labelled ``query0``/``response0``
    and ``query1``/``response1``.

    Args:
        engine: Object exposing ``infer(requests)``; the result is assumed to
            hold one response per request, in request order, each readable
            through ``choices[0].message.content``.
        infer_requests: Requests to run. The ``content`` of each request's
            first message is printed as its query.
    """
    resp_list = engine.infer(infer_requests)

    # Preview at most two pairs. Slicing plus zip (instead of indexing [0] and
    # [1] directly) keeps batches of zero or one request from raising
    # IndexError while producing identical output for larger batches.
    preview = zip(infer_requests[:2], resp_list[:2])
    for idx, (request, response) in enumerate(preview):
        query = request.messages[0]['content']
        print(f'query{idx}: {query}')
        print(f'response{idx}: {response.choices[0].message.content}')
|
|
|
|
|
if __name__ == '__main__':
    # Deferred imports: swift is loaded only when the file runs as a script,
    # which is after the module-level CUDA_VISIBLE_DEVICES assignment.
    from swift.llm import (BaseArguments, InferEngine, InferRequest, PtEngine, load_dataset,
                           safe_snapshot_download)
    from swift.tuners import Swift

    # Download the adapter and build the arguments object from it, then
    # override the length limit and truncation side for this run.
    adapter_dir = safe_snapshot_download('swift/test_bert')
    base_args = BaseArguments.from_pretrained(adapter_dir)
    base_args.max_length = 512
    base_args.truncation_strategy = 'right'

    # Base model -> adapter-wrapped model -> template -> batched engine.
    base_model, processor = base_args.get_model_processor()
    tuned_model = Swift.from_pretrained(base_model, adapter_dir)
    template = base_args.get_template(processor)
    pt_engine = PtEngine.from_model_template(tuned_model, template, max_batch_size=64)

    # First demo: requests built from the first element returned by
    # load_dataset for the 'DAMO_NLP/jd:cls#1000' spec.
    loaded = load_dataset(['DAMO_NLP/jd:cls#1000'], seed=42)
    jd_dataset = loaded[0]
    print(f'dataset: {jd_dataset}')
    dataset_requests = []
    for row in jd_dataset:
        dataset_requests.append(InferRequest(messages=row['messages']))
    infer_batch(pt_engine, dataset_requests)

    # Second demo: two hand-written single-turn user messages.
    sample_texts = ('今天天气真好呀', '真倒霉')
    manual_requests = [
        InferRequest(messages=[{
            'role': 'user',
            'content': text
        }]) for text in sample_texts
    ]
    infer_batch(pt_engine, manual_requests)
|
|