# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is ported from fairseq:
# https://github.com/facebookresearch/fairseq
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of the fairseq repo

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class SentencepieceConfig:
    sentencepiece_model: str = field(
        default="???", metadata={"help": "path to sentencepiece model"}
    )
    sentencepiece_enable_sampling: bool = field(
        default=False, metadata={"help": "enable sampling"}
    )
    sentencepiece_alpha: Optional[float] = field(
        default=None, metadata={
            "help": "soothing parameter for unigram sampling, "
                    "and merge probability for BPE-dropout"
        }
    )


class SentencepieceBPE(object):
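    """Subword tokenizer backed by a SentencePiece model, ported from fairseq."""
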
    def __init__(self, cfg):
        cfg = SentencepieceConfig(**cfg)
        self.enable_sampling = cfg.sentencepiece_enable_sampling
        self.alpha = cfg.sentencepiece_alpha
        sentencepiece_model = cfg.sentencepiece_model
        try:
            import sentencepiece as spm

            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(sentencepiece_model)
        except ImportError:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece"
            )

    def encode(self, x: str) -> str:
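        # Segment the input into subword pieces and join them with spaces;
        # enable_sampling/alpha turn on subword regularization (unigram
        # sampling or BPE-dropout, depending on the model type).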
        return " ".join(
            self.sp.Encode(
                x, out_type=str, enable_sampling=self.enable_sampling, alpha=self.alpha
            )
        )

    def decode(self, x: str) -> str:
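        # SentencePiece prefixes word-initial pieces with "\u2581"; drop the
        # piece separators and turn those markers back into spaces.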
        return x.replace(" ", "").replace("\u2581", " ").strip()

    def is_beginning_of_word(self, x: str) -> bool:
        if x in ["<unk>", "<s>", "</s>", "<pad>"]:
            # special elements are always considered beginnings
            # HACK: this logic is already present in fairseq/tasks/masked_lm.py
            # but these special tokens are also contained in the sentencepiece
            # vocabulary which causes duplicate special tokens. This hack makes
            # sure that they are all taken into account.
            return True
        return x.startswith("\u2581")
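

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the ported module: assumes a
    # trained SentencePiece model exists at the hypothetical path below; any
    # cfg keys omitted here fall back to the SentencepieceConfig defaults.
    bpe = SentencepieceBPE({"sentencepiece_model": "/path/to/spm.model"})
    pieces = bpe.encode("Hello world")  # space-joined pieces, e.g. "\u2581Hello \u2581world"
    print(pieces)
    print(bpe.decode(pieces))  # "Hello world"
    print(bpe.is_beginning_of_word(pieces.split()[0]))  # True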