#!/usr/bin/env python3
"""Run example inference for rubai-corrector-base."""
from __future__ import annotations

import argparse
import json
from pathlib import Path

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Built-in example suite: one input/expected pair per error category the
# corrector targets (abbreviations, apostrophes, OCR confusions, numbers,
# mixed Uzbek/Russian text, and general cleanup).
EXAMPLES = [
    {
        "category": "abbreviation",
        "input": "telefon rqami qaysi",
        "expected": "Telefon raqami qaysi",
    },
    {
        "category": "apostrophe",
        "input": "men ozim kordim",
        "expected": "Men o'zim ko'rdim",
    },
    {
        "category": "apostrophe",
        "input": "togri yoldan boring",
        "expected": "To'g'ri yo'ldan boring",
    },
    {
        "category": "ocr",
        "input": "rnen universitetda oqiyrnan",
        "expected": "Men universitetda o'qiyman",
    },
    {
        "category": "ocr",
        "input": "bu juda rnuhirn masala",
        "expected": "Bu juda muhim masala",
    },
    {
        "category": "numbers",
        "input": "narxi yigirma besh ming so'm",
        "expected": "Narxi 25 000 so'm",
    },
    {
        "category": "numbers",
        "input": "uchrashuv o'n beshinchi yanvar kuni",
        "expected": "Uchrashuv 15-yanvar kuni",
    },
    {
        "category": "mixed_uz_ru",
        "input": "men segodnya bozorga bordim",
        "expected": "Men сегодня bozorga bordim",
    },
    {
        "category": "mixed_script",
        "input": "privet kak делa",
        "expected": "Привет как дела",
    },
    {
        "category": "uzbek_cleanup",
        "input": "xamma narsa tayyor",
        "expected": "Hamma narsa tayyor",
    },
]


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--model-path",
        type=Path,
        default=Path(__file__).resolve().parent,
        help="Path to the packaged model folder.",
    )
    parser.add_argument(
        "--device",
        default="cuda:0" if torch.cuda.is_available() else "cpu",
        help="Inference device, for example cuda:0 or cpu.",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="Run a single custom input instead of the built-in example suite.",
    )
    parser.add_argument(
        "--max-new-tokens",
        type=int,
        default=256,
        help="Maximum generation length.",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Print results as JSON.",
    )
    return parser.parse_args()


def load_model(model_path: Path, device: str):
    # Load the tokenizer and seq2seq weights from the packaged folder,
    # then switch to eval mode so dropout is disabled during inference.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
    model.to(device)
    model.eval()
    return tokenizer, model


def predict(texts: list[str], tokenizer, model, device: str, max_new_tokens: int) -> list[str]:
    # Prefix each input with the "correct:" task prompt, then batch-generate.
    prompts = [f"correct: {text}" for text in texts]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.inference_mode():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)


def main() -> int:
    args = parse_args()
    tokenizer, model = load_model(args.model_path, args.device)

    # Single-input mode: correct one custom string and exit.
    if args.text is not None:
        prediction = predict([args.text], tokenizer, model, args.device, args.max_new_tokens)[0]
        if args.json:
            print(json.dumps({"input": args.text, "prediction": prediction}, ensure_ascii=False, indent=2))
        else:
            print(f"Input: {args.text}")
            print(f"Prediction: {prediction}")
        return 0

    # Suite mode: run every built-in example and record exact-match status.
    predictions = predict(
        [example["input"] for example in EXAMPLES],
        tokenizer,
        model,
        args.device,
        args.max_new_tokens,
    )
    results = []
    for example, prediction in zip(EXAMPLES, predictions):
        results.append(
            {
                "category": example["category"],
                "input": example["input"],
                "expected": example["expected"],
                "prediction": prediction,
                "exact_match": prediction == example["expected"],
            }
        )

    if args.json:
        print(json.dumps(results, ensure_ascii=False, indent=2))
        return 0

    print(f"Model: {args.model_path}")
    print(f"Device: {args.device}")
    print()
    for row in results:
        print(f"[{row['category']}]")
        print(f"Input: {row['input']}")
        print(f"Expected: {row['expected']}")
        print(f"Prediction: {row['prediction']}")
        print(f"Exact: {row['exact_match']}")
        print()
    return 0


if __name__ == "__main__":
    raise SystemExit(main())