|
import json |
|
|
|
from vllm import LLM, SamplingParams |
|
|
|
# Prompt templates. Each has two positional ``str.format`` slots:
#   1st "{}" -> the address string to parse
#   2nd "{}" -> the expected answer (left empty at inference time)
# Literal braces in the worked example are escaped as "{{" / "}}".

# Few-shot variant with "### USER:" / "### ASSISTANT:" markers and one worked
# example. It used to be assigned to ``template`` and then immediately
# overwritten; it is kept under its own name so it stays available without
# being a dead store.
FEW_SHOT_TEMPLATE = """### USER:

請將以下路名解析為 JSON 格式。

輸入:臺北市中正區八德路
輸出:{{"city": "臺北市", "town": "中正區", "road": "八德路"}}

輸入:{}

### ASSISTANT:

{}"""

# ChatML-style variant using "<|im_start|>" / "<|im_end|>" markers.
CHATML_TEMPLATE = """<|im_start|>user
請將以下路名解析為 JSON 格式。

輸入:{}
<|im_end|>
<|im_start|>assistant
{}"""

# Template actually used by build_prompt(); switch variants here.
template = CHATML_TEMPLATE
|
|
|
|
|
|
|
|
|
|
|
def build_prompt(inn, out=""):
    """Render the module-level prompt ``template`` for one example.

    Args:
        inn: Text substituted into the template's first ``{}`` slot (the
            address to parse).
        out: Text substituted into the second ``{}`` slot (the assistant
            answer). Defaults to an empty string, which leaves the prompt
            open for the model to complete.

    Returns:
        The formatted prompt string.
    """
    return template.format(inn, out)
|
|
|
|
|
def iter_dataset(file_path):
    """Yield ``(full, item)`` pairs for every record of a JSON dataset.

    The file is loaded with ``load_json`` and iterated; each record must
    provide the keys ``"city"``, ``"town"`` and ``"road"``.

    Args:
        file_path: Path of the JSON file to read.

    Yields:
        A tuple ``(full, item)`` where ``full`` is the city, town and road
        values concatenated in that order with no separator, and ``item`` is
        the record itself, unchanged.

    Raises:
        KeyError: If a record lacks one of the three keys.
    """
    data = load_json(file_path)

    for item in data:
        city = item["city"]
        town = item["town"]
        road = item["road"]

        full = f"{city}{town}{road}"

        yield full, item
|
|
|
|
|
def load_json(file_path):
    """Read a UTF-8 encoded JSON file and return the decoded object.

    Args:
        file_path: Path of the file to open.

    Returns:
        Whatever ``json.load`` produces for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(file_path, "rt", encoding="UTF-8") as fp:
        return json.load(fp)
|
|
|
|
|
|
|
# Build the evaluation inputs. ``prompts`` and ``items`` are index-aligned:
# prompts[i] is the rendered prompt for the ground-truth record items[i].
prompts, items = [], []
for full, item in iter_dataset("data/test.json"):
    prompts.append(build_prompt(full))
    items.append(item)
|
|
|
|
|
# Load the model to evaluate in half precision.
model_name = "models/Llama-7B-TwAddr-Merged"
llm = LLM(model_name, dtype="float16")

sampling_params = SamplingParams(
    max_tokens=256,
    # Zero temperature selects greedy decoding in vLLM, so repeated runs
    # should score the same completions.
    temperature=0.0,
    # Stop at the first closing brace, i.e. the end of a flat JSON object.
    # By default vLLM leaves the stop string out of the returned text, which
    # is why the scoring code further down appends "}" again before parsing.
    stop=["}"],
)

# One request per prompt, submitted as a single batch.
outputs = llm.generate(prompts, sampling_params)
|
|
|
|
|
def _parse_prediction(text):
    """Extract the JSON value from raw model output, or ``None`` on failure."""
    try:
        begin = text.index("{")
        # "}" is the generation stop string and is not part of ``text``, so
        # it is put back before decoding.
        return json.loads(text[begin:] + "}")
    except (ValueError, RecursionError):
        # ValueError covers both "no opening brace" (str.index) and malformed
        # JSON (json.JSONDecodeError is a ValueError subclass). RecursionError
        # guards against pathologically nested output. Unlike a bare
        # ``except:``, KeyboardInterrupt and SystemExit still propagate.
        return None


# Exact-match scoring: a prediction counts only if the decoded object equals
# the ground-truth record. Mismatches are printed for inspection.
results = []
for out, item in zip(outputs, items):
    pred = _parse_prediction(out.outputs[0].text)

    correct = pred == item
    results.append(correct)
    if not correct:
        print(pred, item)

if results:
    accuracy = sum(results) / len(results)
    print(f"Accuracy: {accuracy:.2%}")
else:
    # An empty test set used to raise ZeroDivisionError here.
    accuracy = None
    print("Accuracy: n/a (no test items)")
|
|