Tokenization fixed
tokenization_nicheformer.py  (+29 -0)
@@ -110,6 +110,7 @@ class NicheformerTokenizer(PreTrainedTokenizer):
         aux_tokens: int = 30,
         median_counts_per_gene: Optional[np.ndarray] = None,
         gene_names: Optional[List[str]] = None,
+        technology_mean: Optional[Union[str, np.ndarray]] = None,
         **kwargs
     ):
         # Initialize base vocabulary
@@ -160,6 +161,25 @@ class NicheformerTokenizer(PreTrainedTokenizer):
         self._mask_token = "[MASK]"
         self._cls_token = "[CLS]"
 
+        # Load technology mean if provided
+        self.technology_mean = None
+        if technology_mean is not None:
+            self._load_technology_mean(technology_mean)
+
+    def _load_technology_mean(self, technology_mean):
+        """Load technology mean from file or array."""
+        if isinstance(technology_mean, str):
+            try:
+                self.technology_mean = np.load(technology_mean)
+                print(f"Loaded technology mean from {technology_mean} with shape {self.technology_mean.shape}")
+            except Exception as e:
+                print(f"Warning: Could not load technology mean from {technology_mean}: {e}")
+        elif isinstance(technology_mean, np.ndarray):
+            self.technology_mean = technology_mean
+            print(f"Using provided technology mean array with shape {self.technology_mean.shape}")
+        else:
+            print(f"Warning: Invalid technology_mean type: {type(technology_mean)}")
+
     def get_vocab(self) -> Dict[str, int]:
         """Returns the vocabulary mapping."""
         return self._vocabulary.copy()
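The new argument is only wired through the constructor. A minimal usage sketch, assuming the import path matches the file name and that the other constructor arguments (aux_tokens, median_counts_per_gene, gene_names) can stay at their defaults; the gene count and the .npy path below are placeholders:

import numpy as np
from tokenization_nicheformer import NicheformerTokenizer  # assumed import path

n_genes = 3000  # placeholder gene count

# Option 1: pass a per-gene mean expression array directly
tech_mean = np.full(n_genes, 0.5, dtype=np.float32)  # hypothetical values
tokenizer = NicheformerTokenizer(technology_mean=tech_mean)

# Option 2: point to a saved .npy file; _load_technology_mean() loads it with np.load()
tokenizer_from_file = NicheformerTokenizer(technology_mean="technology_mean.npy")  # placeholder path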
@@ -232,6 +252,15 @@ class NicheformerTokenizer(PreTrainedTokenizer):
         median_counts += median_counts == 0
         x = x / median_counts.reshape((1, -1))
 
+        # Apply technology mean normalization if available
+        if self.technology_mean is not None and self.technology_mean.shape[0] == x.shape[1]:
+            # Avoid division by zero
+            safe_mean = np.maximum(self.technology_mean, 1e-6)
+            x = x / safe_mean
+
+        # Apply log1p transformation
+        x = np.log1p(x)
+
         # Convert to tokens
         tokens = _sub_tokenize_data(x, self.max_length, self.aux_tokens)
 
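With this change, the count matrix x passes through three steps before _sub_tokenize_data: division by per-gene median counts, optional division by the technology mean (floored at 1e-6 to avoid division by zero), and a log1p transform. A self-contained sketch of just those normalization steps on toy data, independent of the tokenizer class:

import numpy as np

def preprocess_counts(x, median_counts, technology_mean=None):
    """Standalone reproduction of the normalization steps shown above (not the full tokenizer)."""
    median_counts = median_counts.astype(float)
    median_counts += median_counts == 0           # guard zero medians, as in the hunk above
    x = x / median_counts.reshape((1, -1))        # scale each gene by its median count
    if technology_mean is not None and technology_mean.shape[0] == x.shape[1]:
        safe_mean = np.maximum(technology_mean, 1e-6)  # avoid division by zero
        x = x / safe_mean
    return np.log1p(x)                            # log1p transform before token conversion

# Toy example: 2 cells x 4 genes, made-up values
x = np.array([[0., 2., 4., 8.], [1., 0., 2., 4.]])
out = preprocess_counts(x,
                        median_counts=np.array([1., 2., 0., 4.]),
                        technology_mean=np.array([0.5, 1.0, 0.0, 2.0]))
print(out.shape)  # (2, 4)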