igorktech commited on
Commit
0d45d88
1 Parent(s): c14782f

Upload 3 files

Browse files
Files changed (3) hide show
  1. configuration_hat.py +1 -1
  2. modelling_hat.py +0 -1
  3. tokenization_hat.py +22 -0
configuration_hat.py CHANGED
@@ -147,4 +147,4 @@ class HATOnnxConfig(OnnxConfig):
147
  ("input_ids", {0: "batch", 1: "sequence"}),
148
  ("attention_mask", {0: "batch", 1: "sequence"}),
149
  ]
150
- )
 
147
  ("input_ids", {0: "batch", 1: "sequence"}),
148
  ("attention_mask", {0: "batch", 1: "sequence"}),
149
  ]
150
+ )
modelling_hat.py CHANGED
@@ -2357,4 +2357,3 @@ def off_diagonal(x):
2357
  assert n == m
2358
  return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
2359
 
2360
-
 
2357
  assert n == m
2358
  return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()
2359
 
 
tokenization_hat.py CHANGED
@@ -246,4 +246,26 @@ class HATTokenizer:
246
  flat_input[:chunk_size-1],
247
  torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
248
  ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
 
 
246
  flat_input[:chunk_size-1],
247
  torch.tensor([self.pad_token_id] * max(0, chunk_size - len(flat_input) - 1), dtype=torch.int)
248
  ))
249
+
250
+ @classmethod
251
+ def register_for_auto_class(cls, auto_class="AutoModel"):
252
+ """
253
+ Register this class with a given auto class. This should only be used for custom models as the ones in the
254
+ library are already mapped with an auto class.
255
+ <Tip warning={true}>
256
+ This API is experimental and may have some slight breaking changes in the next releases.
257
+ </Tip>
258
+ Args:
259
+ auto_class (`str` or `type`, *optional*, defaults to `"TFAutoModel"`):
260
+ The auto class to register this new model with.
261
+ """
262
+ if not isinstance(auto_class, str):
263
+ auto_class = auto_class.__name__
264
+
265
+ import transformers.models.auto as auto_module
266
+
267
+ if not hasattr(auto_module, auto_class):
268
+ raise ValueError(f"{auto_class} is not a valid auto class.")
269
+
270
+ cls._auto_class = auto_class
271