Fix method name in tokenize_anndata method
tokenization_nicheformer.py (CHANGED: +27 -43)
@@ -33,27 +33,27 @@ class NicheformerTokenizer(PreTrainedTokenizer):
             max_seq_len: Maximum sequence length
             aux_tokens: Number of auxiliary tokens reserved
         """
-        # Initialize the parent class first
-        super().__init__(
-            pad_token="<pad>",
-            eos_token="<eos>",
-            unk_token="<unk>",
-            **kwargs
-        )
-
-        self.max_seq_len = max_seq_len
-        self.aux_tokens = aux_tokens
-
         # Initialize vocabulary
         self.vocab = {}
         self.ids_to_tokens = {}
 
         # Load vocabulary if provided
-        if vocab_file is not None:
+        if vocab_file is not None and os.path.isfile(vocab_file):
             with open(vocab_file, 'r', encoding='utf-8') as f:
                 self.vocab = json.load(f)
             self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
 
+        # Initialize the parent class
+        super().__init__(
+            pad_token="<pad>",
+            eos_token="<eos>",
+            unk_token="<unk>",
+            **kwargs
+        )
+
+        self.max_seq_len = max_seq_len
+        self.aux_tokens = aux_tokens
+
         # Define token constants to match Nicheformer
         self._pad_token_id = 0
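The reordering above matters on recent versions of the transformers library: PreTrainedTokenizer.__init__ registers special tokens through _add_tokens(), which reads self.get_vocab(), so the vocabulary attributes must already exist when super().__init__() runs. Note also that the new os.path.isfile() guard presupposes an `import os` at the top of the module (not shown in this hunk). A minimal sketch of the same initialization order, using a toy vocabulary rather than the real Nicheformer gene vocabulary:

from transformers import PreTrainedTokenizer

class ToyTokenizer(PreTrainedTokenizer):
    """Sketch of the vocab-before-super() pattern; toy names, not Nicheformer's."""

    def __init__(self, **kwargs):
        # Build the vocab BEFORE the parent constructor runs, because
        # PreTrainedTokenizer.__init__ -> _add_tokens() -> self.get_vocab().
        self.vocab = {"<pad>": 0, "<eos>": 1, "<mask>": 2}
        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
        super().__init__(pad_token="<pad>", eos_token="<eos>",
                         unk_token="<unk>", **kwargs)

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab)

    def _tokenize(self, text):
        return text.split()

    def _convert_token_to_id(self, token):
        return self.vocab.get(token, 0)

    def _convert_id_to_token(self, index):
        return self.ids_to_tokens.get(index, "<unk>")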
@@ -86,43 +86,27 @@ class NicheformerTokenizer(PreTrainedTokenizer):
             "CITE-seq": 17,
             "Smart-seq v4": 18,
         }
-
-        # Create a default vocabulary if none was loaded
-        if not self.vocab:
-            self.vocab = {
-                # Special tokens
-                "<pad>": 0,
-                "<eos>": 1,
-                "<mask>": 2,
-            }
-
+
+    def get_vocab(self) -> Dict[str, int]:
+        """Return the vocabulary as a dictionary of token to index."""
+        if not self.vocab:
+            # If vocab is empty, create a minimal vocab with special tokens
+            vocab = {}
+            # Add special tokens
+            vocab["<pad>"] = 0
+            vocab["<eos>"] = 1
+            vocab["<mask>"] = 2
             # Add modality tokens
             for token, idx in self.modality_dict.items():
-                self.vocab[token] = idx
-
+                vocab[token] = idx
             # Add species tokens
             for token, idx in self.specie_dict.items():
-                self.vocab[token] = idx
-
+                vocab[token] = idx
             # Add technology tokens
             for token, idx in self.technology_dict.items():
-                self.vocab[token] = idx
-
-            # In a real implementation, you would add actual gene names here
-
-        # Create reverse vocabulary (id to token)
-        self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
-
-    @property
-    def vocab_size(self):
-        return len(self.vocab)
-
-    def get_vocab(self):
-        return dict(self.vocab)
+            return vocab
+        return self.vocab
 
     def _tokenize(self, text):
         """
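The new get_vocab() means the tokenizer exposes a usable vocabulary even when no vocab_file was loaded: it falls back to the special tokens plus the modality, species, and technology dictionaries. The Dict[str, int] annotation assumes `from typing import Dict` at the top of the module. A hypothetical usage sketch (the constructor signature is assumed, not shown in these hunks):

tok = NicheformerTokenizer(vocab_file=None)  # nothing loaded from disk
vocab = tok.get_vocab()                      # falls back to built-in tokens
assert vocab["<pad>"] == 0 and vocab["<eos>"] == 1 and vocab["<mask>"] == 2
assert vocab["CITE-seq"] == 17               # technology tokens are included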