aletlvl committed on
Commit
ec49fa4
·
verified ·
1 Parent(s): d251649

Fix method name in tokenize_anndata method

Browse files
Files changed (1) hide show
  1. tokenization_nicheformer.py +27 -43
tokenization_nicheformer.py CHANGED
@@ -33,27 +33,27 @@ class NicheformerTokenizer(PreTrainedTokenizer):
33
  max_seq_len: Maximum sequence length
34
  aux_tokens: Number of auxiliary tokens reserved
35
  """
36
- # Initialize the parent class first
37
- super().__init__(
38
- pad_token="<pad>",
39
- eos_token="<eos>",
40
- unk_token="<unk>",
41
- **kwargs
42
- )
43
-
44
- self.max_seq_len = max_seq_len
45
- self.aux_tokens = aux_tokens
46
-
47
  # Initialize vocabulary
48
  self.vocab = {}
49
  self.ids_to_tokens = {}
50
 
51
  # Load vocabulary if provided
52
- if vocab_file is not None:
53
  with open(vocab_file, 'r', encoding='utf-8') as f:
54
  self.vocab = json.load(f)
55
  self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
56
 
 
 
 
 
 
 
 
 
 
 
 
57
  # Define token constants to match Nicheformer
58
  self._pad_token_id = 0
59
 
@@ -86,43 +86,27 @@ class NicheformerTokenizer(PreTrainedTokenizer):
86
  "CITE-seq": 17,
87
  "Smart-seq v4": 18,
88
  }
89
-
90
- # Create vocabulary
91
- if vocab_file is not None and os.path.isfile(vocab_file):
92
- with open(vocab_file, "r", encoding="utf-8") as f:
93
- self.vocab = json.load(f)
94
- else:
95
- # Create a basic vocabulary with special tokens
96
- self.vocab = {
97
- "<pad>": 0,
98
- "<unk>": 1,
99
- "<mask>": 2,
100
- }
101
-
102
  # Add modality tokens
103
  for token, idx in self.modality_dict.items():
104
- self.vocab[f"<modality_{token}>"] = idx
105
-
106
  # Add species tokens
107
  for token, idx in self.specie_dict.items():
108
- self.vocab[f"<species_{token}>"] = idx
109
-
110
  # Add technology tokens
111
  for token, idx in self.technology_dict.items():
112
- self.vocab[f"<technology_{token}>"] = idx
113
-
114
- # Reserve space for gene tokens (starting from aux_tokens)
115
- # In a real implementation, you would add actual gene names here
116
-
117
- # Create reverse vocabulary (id to token)
118
- self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
119
-
120
- @property
121
- def vocab_size(self):
122
- return len(self.vocab)
123
-
124
- def get_vocab(self):
125
- return dict(self.vocab)
126
 
127
  def _tokenize(self, text):
128
  """
 
33
  max_seq_len: Maximum sequence length
34
  aux_tokens: Number of auxiliary tokens reserved
35
  """
 
 
 
 
 
 
 
 
 
 
 
36
  # Initialize vocabulary
37
  self.vocab = {}
38
  self.ids_to_tokens = {}
39
 
40
  # Load vocabulary if provided
41
+ if vocab_file is not None and os.path.isfile(vocab_file):
42
  with open(vocab_file, 'r', encoding='utf-8') as f:
43
  self.vocab = json.load(f)
44
  self.ids_to_tokens = {v: k for k, v in self.vocab.items()}
45
 
46
+ # Initialize the parent class
47
+ super().__init__(
48
+ pad_token="<pad>",
49
+ eos_token="<eos>",
50
+ unk_token="",
51
+ **kwargs
52
+ )
53
+
54
+ self.max_seq_len = max_seq_len
55
+ self.aux_tokens = aux_tokens
56
+
57
  # Define token constants to match Nicheformer
58
  self._pad_token_id = 0
59
 
 
86
  "CITE-seq": 17,
87
  "Smart-seq v4": 18,
88
  }
89
+
90
+ def get_vocab(self) -> Dict[str, int]:
91
+ """Return the vocabulary as a dictionary of token to index."""
92
+ if not self.vocab:
93
+ # If vocab is empty, create a minimal vocab with special tokens
94
+ vocab = {}
95
+ # Add special tokens
96
+ vocab["<pad>"] = 0
97
+ vocab["<eos>"] = 1
98
+ vocab[""] = 2
 
 
 
99
  # Add modality tokens
100
  for token, idx in self.modality_dict.items():
101
+ vocab[token] = idx
 
102
  # Add species tokens
103
  for token, idx in self.specie_dict.items():
104
+ vocab[token] = idx
 
105
  # Add technology tokens
106
  for token, idx in self.technology_dict.items():
107
+ vocab[token] = idx
108
+ return vocab
109
+ return self.vocab
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  def _tokenize(self, text):
112
  """