kiddothe2b commited on
Commit
9638d60
1 Parent(s): 97ba9b9

Add mandatory register_for_auto_class function for transformers-4.26+ support

Browse files
Files changed (1) hide show
  1. tokenization_hat.py +22 -0
tokenization_hat.py CHANGED
@@ -247,3 +247,25 @@ class HATTokenizer:
247
  torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
248
  ))
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
248
  ))
249
 
250
+ @classmethod
251
+ def register_for_auto_class(cls, auto_class="AutoModel"):
252
+ """
253
+ Register this class with a given auto class. This should only be used for custom models as the ones in the
254
+ library are already mapped with an auto class.
255
+ <Tip warning={true}>
256
+ This API is experimental and may have some slight breaking changes in the next releases.
257
+ </Tip>
258
+ Args:
259
+ auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
260
+ The auto class to register this new model with.
261
+ """
262
+ if not isinstance(auto_class, str):
263
+ auto_class = auto_class.__name__
264
+
265
+ import transformers.models.auto as auto_module
266
+
267
+ if not hasattr(auto_module, auto_class):
268
+ raise ValueError(f"{auto_class} is not a valid auto class.")
269
+
270
+ cls._auto_class = auto_class
271
+