kiddothe2b committed on
Commit
af99e83
1 Parent(s): 9e40d21

Add HAT implementation files

Browse files
Files changed (1) hide show
  1. tokenization_hat.py +7 -2
tokenization_hat.py CHANGED
@@ -12,7 +12,7 @@
12
  # limitations under the License.
13
  """Tokenization classes for HAT."""
14
  import torch
15
- from transformers import AutoTokenizer
16
  from .configuration_hat import HATConfig
17
  from transformers.utils import logging
18
  try:
@@ -92,7 +92,11 @@ class HATTokenizer:
92
 
93
  @classmethod
94
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
95
- return cls(tokenizer=AutoTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs))
 
 
 
 
96
 
97
  def save_pretrained(self, *args, **kwargs):
98
  return self._tokenizer.save_pretrained( *args, **kwargs)
@@ -242,3 +246,4 @@ class HATTokenizer:
242
  flat_input[:chunk_size-1],
243
  torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
244
  ))
 
 
12
  # limitations under the License.
13
  """Tokenization classes for HAT."""
14
  import torch
15
+ from transformers import RobertaTokenizer, BertTokenizer
16
  from .configuration_hat import HATConfig
17
  from transformers.utils import logging
18
  try:
 
92
 
93
  @classmethod
94
  def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
95
+ try:
96
+ tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
97
+ except:
98
+ tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, **kwargs)
99
+ return cls(tokenizer=tokenizer)
100
 
101
  def save_pretrained(self, *args, **kwargs):
102
  return self._tokenizer.save_pretrained( *args, **kwargs)
 
246
  flat_input[:chunk_size-1],
247
  torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
248
  ))
249
+