kiddothe2b
committed on
Commit
•
af99e83
1
Parent(s):
9e40d21
Add HAT implementation files
Browse files — tokenization_hat.py (+7 −2)
tokenization_hat.py
CHANGED
@@ -12,7 +12,7 @@
|
|
12 |
# limitations under the License.
|
13 |
"""Tokenization classes for HAT."""
|
14 |
import torch
|
15 |
-
from transformers import
|
16 |
from .configuration_hat import HATConfig
|
17 |
from transformers.utils import logging
|
18 |
try:
|
@@ -92,7 +92,11 @@ class HATTokenizer:
|
|
92 |
|
93 |
@classmethod
|
94 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
95 |
-
|
|
|
|
|
|
|
|
|
96 |
|
97 |
def save_pretrained(self, *args, **kwargs):
|
98 |
return self._tokenizer.save_pretrained( *args, **kwargs)
|
@@ -242,3 +246,4 @@ class HATTokenizer:
|
|
242 |
flat_input[:chunk_size-1],
|
243 |
torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
|
244 |
))
|
|
|
|
12 |
# limitations under the License.
|
13 |
"""Tokenization classes for HAT."""
|
14 |
import torch
|
15 |
+
from transformers import RobertaTokenizer, BertTokenizer
|
16 |
from .configuration_hat import HATConfig
|
17 |
from transformers.utils import logging
|
18 |
try:
|
|
|
92 |
|
93 |
@classmethod
|
94 |
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
|
95 |
+
try:
|
96 |
+
tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
97 |
+
except:
|
98 |
+
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
99 |
+
return cls(tokenizer=tokenizer)
|
100 |
|
101 |
def save_pretrained(self, *args, **kwargs):
|
102 |
return self._tokenizer.save_pretrained( *args, **kwargs)
|
|
|
246 |
flat_input[:chunk_size-1],
|
247 |
torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
|
248 |
))
|
249 |
+
|