tokenization_nicheformer.py updated
Browse files
tokenization_nicheformer.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from typing import List, Dict, Optional, Union, Tuple
|
| 2 |
import numpy as np
|
| 3 |
-
from transformers import PreTrainedTokenizer
|
| 4 |
from dataclasses import dataclass
|
| 5 |
import torch
|
| 6 |
import anndata as ad
|
|
@@ -405,7 +405,4 @@ class NicheformerTokenizer(PreTrainedTokenizer):
|
|
| 405 |
def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False) -> List[int]:
    """Return a mask with 1 for special tokens and 0 for regular tokens.

    Token ids below ``self.aux_tokens`` belong to the auxiliary (special)
    vocabulary, so specialness is decided purely by id value.

    Args:
        token_ids_0: Ids of the first sequence.
        token_ids_1: Optional ids of a second sequence. Previously ignored;
            now appended so pair inputs receive a full-length mask, matching
            the ``PreTrainedTokenizer`` contract.
        already_has_special_tokens: Accepted for API compatibility; the mask
            is computed identically either way because the check is id-based.

    Returns:
        A list of 0/1 ints, one entry per token id across both sequences.
    """
    # Flatten the pair (if any) so the mask covers every input token.
    all_ids = token_ids_0 if token_ids_1 is None else token_ids_0 + token_ids_1
    return [1 if token_id < self.aux_tokens else 0 for token_id in all_ids]
|
| 409 |
-
|
| 410 |
-
# Register the tokenizer class with the Auto* factory so that
# AutoTokenizer.from_pretrained can resolve the "nicheformer" type.
# NOTE(review): AutoTokenizer is not present in this module's visible import
# block, and transformers' AutoTokenizer.register expects a config class and
# a tokenizer class object rather than strings — confirm this line actually
# executes without raising (the newer revision of this file removes it).
AutoTokenizer.register("nicheformer", "tokenization_nicheformer.NicheformerTokenizer")
|
|
|
|
| 1 |
from typing import List, Dict, Optional, Union, Tuple
|
| 2 |
import numpy as np
|
| 3 |
+
from transformers import PreTrainedTokenizer
|
| 4 |
from dataclasses import dataclass
|
| 5 |
import torch
|
| 6 |
import anndata as ad
|
|
|
|
| 405 |
def get_special_tokens_mask(self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False) -> List[int]:
|
| 406 |
"""Get list where entries are [1] if a token is [special] else [0]."""
|
| 407 |
# Consider tokens < aux_tokens as special
|
| 408 |
+
return [1 if token_id < self.aux_tokens else 0 for token_id in token_ids_0]
|
|
|
|
|
|
|
|
|