victan commited on
Commit
452004c
1 Parent(s): e52cb05

Upload seamless_communication/models/tokenizer.py with huggingface_hub

Browse files
seamless_communication/models/tokenizer.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # MIT_LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Optional, Sequence, Set, final
8
+
9
+ from fairseq2.data.text import (
10
+ SentencePieceDecoder,
11
+ SentencePieceEncoder,
12
+ SentencePieceModel,
13
+ TextTokenDecoder,
14
+ TextTokenEncoder,
15
+ TextTokenizer,
16
+ vocab_info_from_sentencepiece,
17
+ )
18
+ from fairseq2.data.typing import PathLike
19
+ from fairseq2.typing import Device, finaloverride
20
+
21
+
22
+ @final
23
+ class SPMTokenizer(TextTokenizer):
24
+ """Represents standard SPM-based tokenizer used in MT tasks"""
25
+
26
+ model: SentencePieceModel
27
+ langs: Set[str]
28
+ prepend_target_langtok_to_target: bool
29
+
30
+ def __init__(
31
+ self,
32
+ pathname: PathLike,
33
+ langs: Sequence[str],
34
+ prepend_target_langtok_to_target: bool = True,
35
+ ) -> None:
36
+ """
37
+ :param pathname:
38
+ The pathname of the SentencePiece model file.
39
+ :param langs:
40
+ The list of supported languages.
41
+ :param default_lang:
42
+ The fall-back language if no language is specified.
43
+ """
44
+ self.langs = set(langs)
45
+ self.prepend_target_langtok_to_target = prepend_target_langtok_to_target
46
+
47
+ # Each language is represented by a `__lang__` control symbol.
48
+ control_symbols = [self._lang_tok_to_internal(lang) for lang in sorted(langs)]
49
+ self.model = SentencePieceModel(pathname, control_symbols)
50
+ vocab_info = vocab_info_from_sentencepiece(self.model)
51
+ super().__init__(vocab_info)
52
+
53
+ @classmethod
54
+ def _lang_tok_to_internal(cls, lang: str) -> str:
55
+ return f"__{lang}__"
56
+
57
+ @finaloverride
58
+ def create_encoder(
59
+ self,
60
+ *,
61
+ task: Optional[str] = None,
62
+ lang: Optional[str] = None,
63
+ mode: Optional[str] = None,
64
+ device: Optional[Device] = None,
65
+ pin_memory: bool = False,
66
+ ) -> TextTokenEncoder:
67
+ """Create a token encoder.
68
+
69
+ :param task:
70
+ Must be 'translation'. If ``None``, defaults to 'translation'.
71
+ :param lang:
72
+ A language from :attr:`langs`. If ``None``, defaults to
73
+ :attr:`default_lang`.
74
+ :param mode:
75
+ Must be 'source' or 'target'.
76
+ :param device:
77
+ The device on which to construct tensors.
78
+ :param pin_memory:
79
+ If ``True``, uses pinned memory while constructing tensors.
80
+ """
81
+ if task is not None and task != "translation":
82
+ raise ValueError(f"`task` must be 'translation', but is '{task}' instead.")
83
+
84
+ assert lang is not None
85
+
86
+ if lang not in self.langs:
87
+ raise ValueError(
88
+ f"`lang` must be a supported language, but is '{lang}' instead."
89
+ )
90
+
91
+ if mode is None or mode == "source":
92
+ prefix_tokens = []
93
+ suffix_tokens = ["</s>"]
94
+ elif mode == "target":
95
+ prefix_tokens = (
96
+ ["</s>"] + [self._lang_tok_to_internal(lang)]
97
+ if self.prepend_target_langtok_to_target
98
+ else []
99
+ )
100
+ suffix_tokens = ["</s>"]
101
+ else:
102
+ raise ValueError(
103
+ f"`mode` must be 'source' or 'target', but is '{mode}' instead."
104
+ )
105
+
106
+ return SentencePieceEncoder(
107
+ self.model,
108
+ prefix_tokens=prefix_tokens,
109
+ suffix_tokens=suffix_tokens,
110
+ device=device,
111
+ pin_memory=pin_memory,
112
+ )
113
+
114
+ @finaloverride
115
+ def create_raw_encoder(
116
+ self, *, device: Optional[Device] = None, pin_memory: bool = False
117
+ ) -> TextTokenEncoder:
118
+ return SentencePieceEncoder(self.model, device=device, pin_memory=pin_memory)
119
+
120
+ @finaloverride
121
+ def create_decoder(self) -> TextTokenDecoder:
122
+ return SentencePieceDecoder(self.model)