| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Prediction script for Vietnamese Word Segmentation. |
| |
| Uses underthesea regex_tokenize to split text into syllables, |
| then applies CRF model at syllable level to decide word boundaries. |
| |
| Usage: |
| uv run scripts/predict_word_segmentation.py "Trên thế giới, giá vàng đang giao dịch" |
| echo "Text here" | uv run scripts/predict_word_segmentation.py - |
| """ |
|
|
| import sys |
| from pathlib import Path |
|
|
| import click |
| import pycrfsuite |
| from underthesea.pipeline.word_tokenize.regex_tokenize import tokenize as regex_tokenize |
|
|
|
|
def get_syllable_at(syllables, position, offset):
    """Return the syllable at position + offset, or a boundary marker.

    Indices before the start of the sequence yield "__BOS__"; indices past
    the end yield "__EOS__".
    """
    target = position + offset
    if 0 <= target < len(syllables):
        return syllables[target]
    return "__BOS__" if target < 0 else "__EOS__"
|
|
|
|
def is_punct(s):
    """Return True when *s* is a single non-alphanumeric character.

    Multi-character strings are never treated as punctuation.
    """
    if len(s) != 1:
        return False
    return not s.isalnum()
|
|
|
|
def load_dictionary(path):
    """Load a word dictionary from *path* (one word per line).

    Blank lines are skipped; surrounding whitespace is stripped from each
    entry. Returns a set of words.
    """
    with open(path, encoding="utf-8") as handle:
        return {entry for line in handle if (entry := line.strip())}
|
|
|
|
def extract_syllable_features(syllables, position, dictionary=None):
    """Extract CRF features for the syllable at *position*.

    Produces a dict of feature-name -> value covering: the current
    syllable's surface form and shape, a context window of two syllables on
    each side (with "__BOS__"/"__EOS__" at the edges), bigram/trigram
    conjunctions, and — when *dictionary* is given — longest-n-gram
    dictionary-membership features.

    NOTE(review): these templates must match the ones used when the CRF
    model was trained, or the loaded weights will not apply — confirm
    against the training script before changing anything here.
    """
    features = {}

    # Current syllable. Boundary markers keep their literal value and get
    # neutral ("False"/"0") shape features instead of string-method results.
    s0 = get_syllable_at(syllables, position, 0)
    is_boundary = s0 in ("__BOS__", "__EOS__")

    features["S[0]"] = s0
    features["S[0].lower"] = s0.lower() if not is_boundary else s0
    features["S[0].istitle"] = str(s0.istitle()) if not is_boundary else "False"
    features["S[0].isupper"] = str(s0.isupper()) if not is_boundary else "False"
    features["S[0].isdigit"] = str(s0.isdigit()) if not is_boundary else "False"
    features["S[0].ispunct"] = str(is_punct(s0)) if not is_boundary else "False"
    features["S[0].len"] = str(len(s0)) if not is_boundary else "0"
    features["S[0].prefix2"] = s0[:2] if not is_boundary and len(s0) >= 2 else s0
    features["S[0].suffix2"] = s0[-2:] if not is_boundary and len(s0) >= 2 else s0

    # Left context: previous two syllables.
    s_1 = get_syllable_at(syllables, position, -1)
    s_2 = get_syllable_at(syllables, position, -2)
    features["S[-1]"] = s_1
    features["S[-1].lower"] = s_1.lower() if s_1 not in ("__BOS__", "__EOS__") else s_1
    features["S[-2]"] = s_2
    features["S[-2].lower"] = s_2.lower() if s_2 not in ("__BOS__", "__EOS__") else s_2

    # Right context: next two syllables.
    s1 = get_syllable_at(syllables, position, 1)
    s2 = get_syllable_at(syllables, position, 2)
    features["S[1]"] = s1
    features["S[1].lower"] = s1.lower() if s1 not in ("__BOS__", "__EOS__") else s1
    features["S[2]"] = s2
    features["S[2].lower"] = s2.lower() if s2 not in ("__BOS__", "__EOS__") else s2

    # Bigram conjunctions around the current syllable.
    features["S[-1,0]"] = f"{s_1}|{s0}"
    features["S[0,1]"] = f"{s0}|{s1}"

    # Trigram conjunction centered on the current syllable.
    features["S[-1,0,1]"] = f"{s_1}|{s0}|{s1}"

    # Dictionary features: record the longest lowercased n-gram (2..5
    # syllables) found in the dictionary that ends, respectively starts,
    # at the current syllable; "0" when no n-gram matches.
    if dictionary is not None:
        n = len(syllables)

        if position >= 1:
            match = ""
            # length is capped at min(5, position + 1) by the range bound.
            for length in range(2, min(6, position + 2)):
                start = position - length + 1
                # start >= 0 always holds given the range bound above;
                # kept as a defensive guard.
                if start >= 0:
                    ngram = " ".join(syllables[start:position + 1]).lower()
                    if ngram in dictionary:
                        match = ngram
            features["S[-1,0].in_dict"] = match if match else "0"

        if position < n - 1:
            match = ""
            # length is capped at min(5, n - position) by the range bound.
            for length in range(2, min(6, n - position + 1)):
                ngram = " ".join(syllables[position:position + length]).lower()
                if ngram in dictionary:
                    match = ngram
            features["S[0,1].in_dict"] = match if match else "0"

    return features
|
|
|
|
def sentence_to_syllable_features(syllables, dictionary=None):
    """Build the CRF input for a sentence.

    Returns one list of "key=value" feature strings per syllable, in
    sentence order.
    """
    rows = []
    for index in range(len(syllables)):
        feature_map = extract_syllable_features(syllables, index, dictionary)
        rows.append([f"{key}={value}" for key, value in feature_map.items()])
    return rows
|
|
|
|
def labels_to_words(syllables, labels):
    """Reassemble words from a syllable sequence and its BIO labels.

    A "B" label starts a new word; any other label continues the current
    one. Syllables within a word are joined by single spaces.
    """
    words = []
    buffer = []

    for syllable, tag in zip(syllables, labels):
        if tag == "B" and buffer:
            words.append(" ".join(buffer))
            buffer = []
        buffer.append(syllable)

    if buffer:
        words.append(" ".join(buffer))

    return words
|
|
|
|
def segment_text(text, tagger, dictionary=None):
    """
    Full pipeline: regex tokenize -> CRF segment -> output words.

    Returns the segmented sentence with words separated by spaces and the
    syllables of a multi-syllable word joined by underscores, e.g.
    "Tr\u00ean th\u1ebf_gi\u1edbi , gi\u00e1 v\u00e0ng ...". Returns "" for input that tokenizes
    to nothing.
    """
    # Split raw text into syllable tokens (underthesea regex tokenizer).
    syllables = regex_tokenize(text)
    if not syllables:
        return ""

    # Feature extraction, then CRF tagging at syllable level.
    X = sentence_to_syllable_features(syllables, dictionary)
    labels = tagger.tag(X)

    # Merge syllables back into words according to the B/I labels.
    words = labels_to_words(syllables, labels)

    # BUG FIX: the previous replace() chain
    #   "_".join(words).replace(" ", "_").replace("_", " ").replace(" ", " _ ")
    # first erased every underscore and then inserted " _ " at *every*
    # syllable boundary, discarding the segmentation the CRF produced.
    # Join words with spaces and intra-word syllables with underscores,
    # matching segment_text_formatted(use_underscore=True).
    return " ".join(word.replace(" ", "_") for word in words)
|
|
|
|
def segment_text_formatted(text, tagger, use_underscore=True, dictionary=None):
    """
    Full pipeline with formatted output.

    Tokenizes *text* into syllables, tags them with the CRF *tagger*, and
    joins the resulting words with spaces. With use_underscore=True the
    syllables inside a word are joined by "_"; otherwise their internal
    spaces are kept. Returns "" for input that tokenizes to nothing.
    """
    syllables = regex_tokenize(text)
    if not syllables:
        return ""

    feature_rows = sentence_to_syllable_features(syllables, dictionary)
    words = labels_to_words(syllables, tagger.tag(feature_rows))

    if not use_underscore:
        return " ".join(words)
    return " ".join(word.replace(" ", "_") for word in words)
|
|
|
|
@click.command()
@click.argument("text", required=False)
@click.option(
    "--model", "-m",
    default="word_segmenter.crfsuite",
    help="Path to CRF model file",
    show_default=True,
)
@click.option(
    "--underscore/--no-underscore",
    default=True,
    help="Use underscore to join compound word syllables",
)
def main(text, model, underscore):
    """Segment Vietnamese text into words."""
    # Read from stdin when TEXT is "-" or omitted, so the script supports
    # both argument and pipe usage (see module docstring).
    if text == "-" or text is None:
        text = sys.stdin.read().strip()

    if not text:
        click.echo("No input text provided", err=True)
        return

    # Model loading: a ".crf" file is loaded with the underthesea_core
    # backend (the flat vs. nested import layout varies between package
    # versions, hence the ImportError fallback); any other path is opened
    # as a pycrfsuite model.
    if model.endswith(".crf"):
        try:
            from underthesea_core import CRFModel, CRFTagger
        except ImportError:
            from underthesea_core.underthesea_core import CRFModel, CRFTagger
        crf_model = CRFModel.load(model)
        tagger = CRFTagger.from_model(crf_model)
    else:
        tagger = pycrfsuite.Tagger()
        tagger.open(model)

    # Optional dictionary features: use a dictionary.txt sitting next to
    # the model file, if present.
    model_dir = Path(model).parent
    dict_path = model_dir / "dictionary.txt"
    dictionary = load_dictionary(dict_path) if dict_path.exists() else None
    if dictionary:
        # Report on stderr so stdout stays clean for the segmented output.
        click.echo(f"Dictionary: {len(dictionary)} words", err=True)

    # Segment and print line by line, skipping blank lines.
    for line in text.split("\n"):
        if line.strip():
            result = segment_text_formatted(line, tagger, use_underscore=underscore, dictionary=dictionary)
            click.echo(result)
|
|
|
|
if __name__ == "__main__":
    # click parses sys.argv and dispatches to main() when run as a script.
    main()
|
|