aletlvl commited on
Commit
211103a
·
verified ·
1 Parent(s): fadb7a9

Tokenization with model alignment

Browse files
Files changed (1) hide show
  1. tokenization_nicheformer.py +22 -0
tokenization_nicheformer.py CHANGED
@@ -8,6 +8,7 @@ from scipy.sparse import issparse
8
  import numba
9
  import os
10
  import json
 
11
 
12
  # Token IDs must match exactly with the original implementation
13
  PAD_TOKEN = 0
@@ -88,6 +89,19 @@ class NicheformerTokenizer(PreTrainedTokenizer):
88
  species_dict = SPECIES_DICT
89
  technology_dict = TECHNOLOGY_DICT
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  def __init__(
92
  self,
93
  vocab_file=None,
@@ -138,6 +152,7 @@ class NicheformerTokenizer(PreTrainedTokenizer):
138
  self.aux_tokens = aux_tokens
139
  self.median_counts_per_gene = median_counts_per_gene
140
  self.gene_names = gene_names
 
141
 
142
  # Set up special token mappings
143
  self._pad_token = "[PAD]"
@@ -243,6 +258,13 @@ class NicheformerTokenizer(PreTrainedTokenizer):
243
  Dictionary with model inputs
244
  """
245
  if adata is not None:
 
 
 
 
 
 
 
246
  # Get expression matrix
247
  if issparse(adata.X):
248
  x = adata.X.toarray()
 
8
  import numba
9
  import os
10
  import json
11
+ from huggingface_hub import hf_hub_download
12
 
13
  # Token IDs must match exactly with the original implementation
14
  PAD_TOKEN = 0
 
89
  species_dict = SPECIES_DICT
90
  technology_dict = TECHNOLOGY_DICT
91
 
92
+ def _load_reference_model(self):
93
+ """Load reference model for gene alignment."""
94
+ try:
95
+ # Get the model name or path from the tokenizer
96
+ repo_id = self.name_or_path if hasattr(self, "name_or_path") else "aletlvl/Nicheformer"
97
+
98
+ # Download the reference model if not already cached
99
+ model_path = hf_hub_download(repo_id=repo_id, filename="model.h5ad")
100
+ return ad.read_h5ad(model_path)
101
+ except Exception as e:
102
+ print(f"Warning: Could not load reference model: {e}")
103
+ return None
104
+
105
  def __init__(
106
  self,
107
  vocab_file=None,
 
152
  self.aux_tokens = aux_tokens
153
  self.median_counts_per_gene = median_counts_per_gene
154
  self.gene_names = gene_names
155
+ self.name_or_path = kwargs.get('name_or_path', 'aletlvl/Nicheformer')
156
 
157
  # Set up special token mappings
158
  self._pad_token = "[PAD]"
 
258
  Dictionary with model inputs
259
  """
260
  if adata is not None:
261
+ # Align with reference model if needed
262
+ reference_model = self._load_reference_model()
263
+ if reference_model is not None:
264
+ # Concatenate and then remove the reference
265
+ adata = ad.concat([reference_model, adata], join='outer', axis=0)
266
+ adata = adata[1:]
267
+
268
  # Get expression matrix
269
  if issparse(adata.X):
270
  x = adata.X.toarray()