aletlvl committed
Commit d251649 · verified · 1 Parent(s): 1710359

Fix method name in tokenize_anndata method

Files changed (1)
  1. tokenization_nicheformer.py +21 -10
tokenization_nicheformer.py CHANGED
@@ -33,11 +33,29 @@ class NicheformerTokenizer(PreTrainedTokenizer):
             max_seq_len: Maximum sequence length
             aux_tokens: Number of auxiliary tokens reserved
         """
+        # Initialize the parent class first
+        super().__init__(
+            pad_token="<pad>",
+            eos_token="<eos>",
+            unk_token="<unk>",
+            **kwargs
+        )
+
         self.max_seq_len = max_seq_len
         self.aux_tokens = aux_tokens
 
+        # Initialize vocabulary
+        self.vocab = {}
+        self.ids_to_tokens = {}
+
+        # Load vocabulary if provided
+        if vocab_file is not None:
+            with open(vocab_file, 'r', encoding='utf-8') as f:
+                self.vocab = json.load(f)
+            self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
+
         # Define token constants to match Nicheformer
-        self.pad_token_id = 0
+        self._pad_token_id = 0
 
         # Define special token mappings
         self.modality_dict = {
@@ -98,13 +116,6 @@ class NicheformerTokenizer(PreTrainedTokenizer):
 
         # Create reverse vocabulary (id to token)
         self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
-
-        # Set special tokens for parent class
-        kwargs["pad_token"] = "<pad>"
-        kwargs["unk_token"] = "<unk>"
-        kwargs["mask_token"] = "<mask>"
-
-        super().__init__(**kwargs)
 
     @property
     def vocab_size(self):
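Note on the two hunks above: the old constructor registered the special tokens by mutating kwargs and calling super().__init__(**kwargs) at the very end of __init__, while the rewrite calls the parent constructor first, passes the special tokens explicitly, and keeps the pad id in a private _pad_token_id attribute. A minimal usage sketch, assuming the constructor signature shown in the diff; the file name and parameter values are placeholders, and the vocab file is assumed to be a JSON object mapping token strings to integer ids:

from tokenization_nicheformer import NicheformerTokenizer

# "vocab.json", 1500 and 30 are placeholder values, not from this commit.
tokenizer = NicheformerTokenizer(
    vocab_file="vocab.json",
    max_seq_len=1500,
    aux_tokens=30,
)

print(tokenizer.pad_token)       # "<pad>", registered via super().__init__
print(tokenizer._pad_token_id)   # 0, the private constant used for masking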
@@ -226,7 +237,7 @@ class NicheformerTokenizer(PreTrainedTokenizer):
         tokens = self._sub_tokenize_data(X)
 
         # Create attention mask (1 for real tokens, 0 for padding)
-        attention_mask = (tokens != self.pad_token_id).astype(np.int32)
+        attention_mask = (tokens != self._pad_token_id).astype(np.int32)
 
         # Extract metadata from obs
         result = {
@@ -335,7 +346,7 @@ class NicheformerTokenizer(PreTrainedTokenizer):
         tokens = self._sub_tokenize_data(expression_matrix)
 
         # Create attention mask (1 for real tokens, 0 for padding)
-        attention_mask = (tokens != self.pad_token_id).astype(np.int32)
+        attention_mask = (tokens != self._pad_token_id).astype(np.int32)
 
         # Add metadata tokens if provided
         result = {
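Both tokenize_* call sites above now build the attention mask against the private constant. A self-contained NumPy sketch of what that comparison produces, with made-up token ids (only the pad id of 0 comes from the commit):

import numpy as np

_pad_token_id = 0
tokens = np.array([[812, 45, 3, 0, 0],
                   [  7,  0, 0, 0, 0]])   # two right-padded toy sequences

# 1 for real tokens, 0 for padding, mirroring the changed lines in the diff
attention_mask = (tokens != _pad_token_id).astype(np.int32)
# -> [[1 1 1 0 0]
#     [1 0 0 0 0]]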
 
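Why the rename from pad_token_id to _pad_token_id matters: in transformers, SpecialTokensMixin exposes pad_token_id as a property derived from pad_token, so assigning self.pad_token_id = 0 collides with that property machinery (the exact failure mode depends on the transformers version). A stand-alone stand-in, not the transformers source, illustrating the conflict that a plain private attribute avoids:

class Parent:
    @property
    def pad_token_id(self):
        # stand-in for transformers resolving pad_token through the vocab
        return getattr(self, "_pad_token_id", None)

class Child(Parent):
    def __init__(self):
        # self.pad_token_id = 0 would raise AttributeError here (no setter);
        # the private attribute stores the constant without touching the property
        self._pad_token_id = 0

print(Child().pad_token_id)  # 0, resolved through the parent property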