|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field |
|
from typing import Optional |
|
|
|
|
|
@dataclass
class SentencepieceConfig:
    """Options for building a SentencePiece-backed BPE encoder.

    Each field carries a human-readable ``help`` string in its metadata.
    """

    # Path to the trained SentencePiece model file. The "???" default is
    # presumably the OmegaConf/Hydra "mandatory value" marker — confirm.
    sentencepiece_model: str = field(
        default="???", metadata={"help": "path to sentencepiece model"}
    )
    # Whether to sample segmentations instead of using the single best one.
    sentencepiece_enable_sampling: bool = field(
        default=False, metadata={"help": "enable sampling"}
    )
    # Sampling hyper-parameter forwarded to SentencePiece; None leaves it unset.
    sentencepiece_alpha: Optional[float] = field(
        default=None,
        metadata={
            "help": "smoothing parameter for unigram sampling, "
            "and merge probability for BPE-dropout"
        },
    )
|
|
|
|
|
class SentencepieceBPE(object):
    """Thin wrapper that encodes/decodes text with a SentencePiece model.

    Encoded output is the model's pieces joined by single spaces; decoding
    reverses that by dropping the separators and turning the "\u2581"
    word-boundary marker back into spaces.
    """

    # Special symbols treated as word beginnings even though they do not
    # start with the "\u2581" word-boundary marker.
    _SPECIAL_SYMBOLS = frozenset({"<unk>", "<s>", "</s>", "<pad>"})

    def __init__(self, cfg):
        """Load the SentencePiece model described by ``cfg``.

        Args:
            cfg: a ``SentencepieceConfig`` instance, or a mapping of its field
                names to values (unknown keys make the dataclass constructor
                raise ``TypeError``).

        Raises:
            ImportError: if the ``sentencepiece`` package is not installed.
        """
        # Accept an already-built config as well as a mapping of options.
        if not isinstance(cfg, SentencepieceConfig):
            cfg = SentencepieceConfig(**cfg)

        self.enable_sampling = cfg.sentencepiece_enable_sampling
        self.alpha = cfg.sentencepiece_alpha

        # Only the import can raise ImportError, so keep the try body to it
        # and chain the original error for easier debugging.
        try:
            import sentencepiece as spm
        except ImportError as err:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece"
            ) from err

        self.sp = spm.SentencePieceProcessor()
        self.sp.Load(cfg.sentencepiece_model)

    def encode(self, x: str) -> str:
        """Segment ``x`` into pieces and return them joined by single spaces.

        ``enable_sampling`` and ``alpha`` are forwarded to ``Encode`` as-is;
        how ``alpha=None`` is handled depends on the installed sentencepiece
        version (NOTE(review): confirm).
        """
        pieces = self.sp.Encode(
            x, out_type=str, enable_sampling=self.enable_sampling, alpha=self.alpha
        )
        return " ".join(pieces)

    def decode(self, x: str) -> str:
        """Invert ``encode``: drop separators, map "\u2581" to space, strip."""
        return x.replace(" ", "").replace("\u2581", " ").strip()

    def is_beginning_of_word(self, x: str) -> bool:
        """Return True if piece ``x`` starts a word.

        Special symbols always count as word beginnings; any other piece does
        iff it starts with the "\u2581" marker.
        """
        if x in self._SPECIAL_SYMBOLS:
            return True
        return x.startswith("\u2581")
|
|