| |
| |
| |
| |
| |
| |
| |
| |
"""
Inference script for Vietnamese POS Tagger (TRE-1).

Usage:
    uv run scripts/predict.py "Tôi yêu Việt Nam"
    uv run scripts/predict.py --version v1.0.0 "Hà Nội là thủ đô"
    uv run scripts/predict.py --model models/pos_tagger/v1.0.0 "Test"
    echo "Học sinh đang học bài" | uv run scripts/predict.py -
"""
|
|
| import json |
| import sys |
| import os |
| from pathlib import Path |
|
|
| import click |
|
|
| |
# Repository root: the parent of the directory that contains this script.
# It is made absolute up front so the value stays correct even when the
# script path is a bare relative name (where ``Path("x.py").parent.parent``
# would collapse to ``.``) and does not depend on the working directory.
PROJECT_ROOT = Path(os.path.abspath(__file__)).parent.parent

# Put the repository root first on the import path; presumably this is what
# lets the project-local ``handler`` module be imported when this file is
# executed directly as a script.
sys.path.insert(0, str(PROJECT_ROOT))
|
|
| from handler import EndpointHandler |
|
|
|
|
def get_latest_version(task="pos_tagger"):
    """Return the name of the newest model version directory for *task*.

    Version directories are the immediate sub-directories of
    ``PROJECT_ROOT / "models" / task``. Names are ranked in natural order:
    every run of ASCII digits is compared as an integer, so ``v1.10.0`` ranks
    above ``v1.9.0`` (plain string sorting would pick ``v1.9.0``). Names whose
    digit runs have equal width, such as fixed-width timestamps, keep the
    same relative order they have under plain string sorting.

    Args:
        task: Name of the sub-directory of ``models/`` to search.

    Returns:
        The directory name of the latest version, or ``None`` when the task
        directory is missing, is not a directory, or has no sub-directories.
    """
    # Function-scope import: ``re`` is not among the module-level imports.
    import re

    models_dir = PROJECT_ROOT / "models" / task
    if not models_dir.is_dir():
        return None

    def natural_key(name):
        # re.split with a capturing group always yields an alternating
        # [text, digits, text, ...] list, so equal positions always hold the
        # same type and a str is never compared against an int.
        chunks = re.split(r"([0-9]+)", name)
        numeric_aware = [int(c) if i % 2 else c for i, c in enumerate(chunks)]
        # The raw name breaks ties (e.g. "v1.01" vs "v1.1") so the result
        # does not depend on directory listing order.
        return numeric_aware, name

    versions = [entry.name for entry in models_dir.iterdir() if entry.is_dir()]
    return max(versions, key=natural_key, default=None)
|
|
|
|
@click.command()
@click.argument("text", default="-")
@click.option(
    "--version",
    "-v",
    default=None,
    help="Model version to use (default: latest)",
)
@click.option(
    "--model",
    "-m",
    default=None,
    help="Custom model directory path (overrides version-based path)",
)
@click.option(
    "--format",
    "-f",
    "output_format",
    type=click.Choice(["inline", "json", "conll"]),
    default="inline",
    show_default=True,
    help="Output format",
)
def predict(text, version, model, output_format):
    """Tag Vietnamese text with POS tags.

    TEXT is the input text to tag. Use '-' to read from stdin.
    """
    # NOTE: the docstring above doubles as the command's --help text, so it
    # is user-facing output rather than internal documentation.

    # Only look up the newest version on disk when the caller supplied
    # neither an explicit version nor an explicit model directory.
    if model is None and version is None:
        version = get_latest_version("pos_tagger")
        if version is None:
            raise click.ClickException("No models found in models/pos_tagger/")

    # An explicit --model path takes precedence over the version-derived one.
    model_path = model if model else str(PROJECT_ROOT / "models" / "pos_tagger" / version)

    # "-" is the conventional placeholder for "read the text from stdin".
    if text == "-":
        text = sys.stdin.read().strip()
    if not text:
        raise click.ClickException("No input text provided")

    # Load the model and tag the text; the items of the returned sequence are
    # indexed below by their "token" and "tag" keys.
    tagger = EndpointHandler(path=model_path)
    tagged_tokens = tagger({"inputs": text})

    if output_format == "json":
        click.echo(json.dumps(tagged_tokens, ensure_ascii=False, indent=2))
        return

    if output_format == "conll":
        # One token per line: 1-based index, token, tag (tab-separated).
        for position, entry in enumerate(tagged_tokens, start=1):
            click.echo(f"{position}\t{entry['token']}\t{entry['tag']}")
        return

    # Default "inline" format: space-separated token/tag pairs on one line.
    pairs = [f"{entry['token']}/{entry['tag']}" for entry in tagged_tokens]
    click.echo(" ".join(pairs))


# Script entry point: click parses the command line and invokes the command.
if __name__ == "__main__":
    predict()
|
|