| |
"""
tokenizer/train_sp_tokenizer.py — train a SentencePiece Unigram tokenizer for Korean.

Uses the Unigram model so that one Korean syllable (3 bytes in UTF-8) can map to
a single token. character_coverage=0.9995 is intended to cover the full set of
11,172 Hangul syllables.

Usage:
    python tokenizer/train_sp_tokenizer.py \
        --input "data/raw/namuwiki_ko/*.txt,data/raw/ko_wiki_0000.txt" \
        --vocab_size 64000 \
        --output_dir tokenizer/korean_sp

Output:
    tokenizer/korean_sp/tokenizer.model  (SentencePiece model)
    tokenizer/korean_sp/tokenizer.vocab  (vocabulary listing)
"""
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import glob |
| import os |
| import sys |
| import tempfile |
| from pathlib import Path |
|
|
|
|
def expand_inputs(input_spec: str) -> list[str]:
    """Expand a comma-separated list of paths/glob patterns into file paths.

    Entries containing glob metacharacters (``*``, ``?``, ``[``) are expanded
    recursively and sorted; plain entries are kept only if they exist on disk.
    Entries that match nothing produce a warning on stderr instead of failing.
    """
    resolved: list[str] = []
    for raw in input_spec.split(","):
        spec = raw.strip()
        has_glob = "*" in spec or "?" in spec or "[" in spec
        if has_glob:
            hits = sorted(glob.glob(spec, recursive=True))
            if not hits:
                print(f"WARNING: ํจํด์ ์ผ์นํ๋ ํ์ผ ์์: {spec!r}", file=sys.stderr)
            resolved.extend(hits)
        elif Path(spec).exists():
            resolved.append(spec)
        else:
            print(f"WARNING: ํ์ผ ์์: {spec!r}", file=sys.stderr)
    return resolved
|
|
|
|
def train(
    input_files: list[str],
    output_dir: Path,
    vocab_size: int,
    num_threads: int,
    input_sentence_size: int,
) -> None:
    """Train a SentencePiece Unigram tokenizer and write it to *output_dir*.

    Args:
        input_files: Plain-text corpus files used for training.
        output_dir: Directory receiving ``tokenizer.model`` / ``tokenizer.vocab``
            (created if missing).
        vocab_size: Target vocabulary size.
        num_threads: CPU threads used during training.
        input_sentence_size: Max sentences sampled for training (0 = unlimited).

    Exits with status 1 if sentencepiece is not installed or training
    produced no model file.
    """
    # Import lazily so argument errors can be reported without the dependency.
    try:
        import sentencepiece as spm
    except ImportError:
        print(
            "ERROR: sentencepiece is not installed.\n"
            "  pip install --break-system-packages sentencepiece",
            file=sys.stderr,
        )
        sys.exit(1)

    output_dir.mkdir(parents=True, exist_ok=True)
    model_prefix = str(output_dir / "tokenizer")

    print(f"Input files: {len(input_files)}")
    for f in input_files[:5]:
        print(f"  {f}")
    if len(input_files) > 5:
        print(f"  ... and {len(input_files) - 5} more")
    print(f"Vocab size: {vocab_size:,}")
    print(f"Output: {model_prefix}.model / .vocab")
    print()

    # SentencePiece accepts a comma-separated list of input files.
    input_str = ",".join(input_files)

    spm.SentencePieceTrainer.train(
        input=input_str,
        model_prefix=model_prefix,
        vocab_size=vocab_size,
        model_type="unigram",
        # High coverage so (nearly) all Hangul syllables stay in-vocabulary.
        character_coverage=0.9995,
        normalization_rule_name="nfkc",
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3,
        pad_piece="<pad>",
        bos_piece="<s>",
        eos_piece="</s>",
        unk_piece="<unk>",
        user_defined_symbols=[],
        num_threads=num_threads,
        input_sentence_size=input_sentence_size,
        shuffle_input_sentence=True,
        seed_sentencepiece_size=1_000_000,
        shrinking_factor=0.75,
        max_sentence_length=4096,
    )

    model_path = Path(f"{model_prefix}.model")
    vocab_path = Path(f"{model_prefix}.vocab")

    # Guard clause: fail loudly if SentencePiece did not emit a model.
    if not model_path.exists():
        print("ERROR: training failed - no model file was produced", file=sys.stderr)
        sys.exit(1)

    size_mb = model_path.stat().st_size / 1e6
    print("Training complete!")
    print(f"  model: {model_path} ({size_mb:.1f} MB)")
    print(f"  vocab: {vocab_path}")
    print()
    print("Next step:")
    print("  python tokenizer/convert_sp_to_hf.py \\")
    print(f"    --model {model_path} \\")
    print(f"    --output {output_dir}/tokenizer.json")
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the tokenizer training script."""
    parser = argparse.ArgumentParser(
        description="Train a Korean SentencePiece Unigram tokenizer",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--input",
        required=True,
        help="comma-separated files/glob patterns "
        "(e.g. 'data/raw/ko/*.txt,data/raw/wiki.txt')",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=64000,
        help="vocabulary size",
    )
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=Path("tokenizer/korean_sp"),
        help="directory the model is written to",
    )
    parser.add_argument(
        "--num_threads",
        type=int,
        default=64,
        help="number of CPU threads used for training",
    )
    parser.add_argument(
        "--input_sentence_size",
        type=int,
        default=10_000_000,
        help="max number of sentences used for training (0 = unlimited)",
    )
    return parser.parse_args()
|
|
|
|
def main() -> None:
    """CLI entry point: expand input patterns and launch tokenizer training."""
    args = parse_args()
    input_files = expand_inputs(args.input)
    # Fail fast on an empty corpus rather than letting SentencePiece error
    # out much later in training.
    if not input_files:
        print("ERROR: no input files found.", file=sys.stderr)
        sys.exit(1)
    train(
        input_files=input_files,
        output_dir=args.output_dir,
        vocab_size=args.vocab_size,
        num_threads=args.num_threads,
        input_sentence_size=args.input_sentence_size,
    )
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|