aletlvl commited on
Commit
6326f97
·
verified ·
1 Parent(s): 833a4d7

Tokenization fixed

Browse files
Files changed (1) hide show
  1. tokenization_nicheformer.py +29 -0
tokenization_nicheformer.py CHANGED
@@ -110,6 +110,7 @@ class NicheformerTokenizer(PreTrainedTokenizer):
110
  aux_tokens: int = 30,
111
  median_counts_per_gene: Optional[np.ndarray] = None,
112
  gene_names: Optional[List[str]] = None,
 
113
  **kwargs
114
  ):
115
  # Initialize base vocabulary
@@ -160,6 +161,25 @@ class NicheformerTokenizer(PreTrainedTokenizer):
160
  self._mask_token = "[MASK]"
161
  self._cls_token = "[CLS]"
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
  def get_vocab(self) -> Dict[str, int]:
164
  """Returns the vocabulary mapping."""
165
  return self._vocabulary.copy()
@@ -232,6 +252,15 @@ class NicheformerTokenizer(PreTrainedTokenizer):
232
  median_counts += median_counts == 0
233
  x = x / median_counts.reshape((1, -1))
234
 
 
 
 
 
 
 
 
 
 
235
  # Convert to tokens
236
  tokens = _sub_tokenize_data(x, self.max_length, self.aux_tokens)
237
 
 
110
  aux_tokens: int = 30,
111
  median_counts_per_gene: Optional[np.ndarray] = None,
112
  gene_names: Optional[List[str]] = None,
113
+ technology_mean: Optional[Union[str, np.ndarray]] = None,
114
  **kwargs
115
  ):
116
  # Initialize base vocabulary
 
161
  self._mask_token = "[MASK]"
162
  self._cls_token = "[CLS]"
163
 
164
+ # Load technology mean if provided
165
+ self.technology_mean = None
166
+ if technology_mean is not None:
167
+ self._load_technology_mean(technology_mean)
168
+
169
+ def _load_technology_mean(self, technology_mean):
170
+ """Load technology mean from file or array."""
171
+ if isinstance(technology_mean, str):
172
+ try:
173
+ self.technology_mean = np.load(technology_mean)
174
+ print(f"Loaded technology mean from {technology_mean} with shape {self.technology_mean.shape}")
175
+ except Exception as e:
176
+ print(f"Warning: Could not load technology mean from {technology_mean}: {e}")
177
+ elif isinstance(technology_mean, np.ndarray):
178
+ self.technology_mean = technology_mean
179
+ print(f"Using provided technology mean array with shape {self.technology_mean.shape}")
180
+ else:
181
+ print(f"Warning: Invalid technology_mean type: {type(technology_mean)}")
182
+
183
  def get_vocab(self) -> Dict[str, int]:
184
  """Returns the vocabulary mapping."""
185
  return self._vocabulary.copy()
 
252
  median_counts += median_counts == 0
253
  x = x / median_counts.reshape((1, -1))
254
 
255
+ # Apply technology mean normalization if available
256
+ if self.technology_mean is not None and self.technology_mean.shape[0] == x.shape[1]:
257
+ # Avoid division by zero
258
+ safe_mean = np.maximum(self.technology_mean, 1e-6)
259
+ x = x / safe_mean
260
+
261
+ # Apply log1p transformation
262
+ x = np.log1p(x)
263
+
264
  # Convert to tokens
265
  tokens = _sub_tokenize_data(x, self.max_length, self.aux_tokens)
266