Tokenization fixed
Browse files- tokenization_nicheformer.py +8 -17
tokenization_nicheformer.py
CHANGED
|
@@ -287,11 +287,6 @@ class NicheformerTokenizer(PreTrainedTokenizer):
|
|
| 287 |
for col in ['modality', 'specie', 'assay']:
|
| 288 |
if col in adata.obs.columns:
|
| 289 |
original_types[col] = adata.obs[col].dtype
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
print(f"modality dtype: {adata.obs['modality'].dtype}")
|
| 293 |
-
print(f"specie dtype: {adata.obs['specie'].dtype}")
|
| 294 |
-
print(f"assay dtype: {adata.obs['assay'].dtype}")
|
| 295 |
|
| 296 |
# Concatenate and then remove the reference
|
| 297 |
adata = ad.concat([reference_model, adata], join='outer', axis=0)
|
|
@@ -304,10 +299,6 @@ class NicheformerTokenizer(PreTrainedTokenizer):
|
|
| 304 |
adata.obs[col] = adata.obs[col].astype(dtype)
|
| 305 |
except Exception as e:
|
| 306 |
print(f"Warning: Could not convert {col} back to {dtype}: {e}")
|
| 307 |
-
|
| 308 |
-
print(f"modality dtype: {adata.obs['modality'].dtype}")
|
| 309 |
-
print(f"specie dtype: {adata.obs['specie'].dtype}")
|
| 310 |
-
print(f"assay dtype: {adata.obs['assay'].dtype}")
|
| 311 |
|
| 312 |
# Get gene expression data
|
| 313 |
X = adata.X
|
|
@@ -362,20 +353,20 @@ class NicheformerTokenizer(PreTrainedTokenizer):
|
|
| 362 |
# Tokenize gene expression data
|
| 363 |
token_ids = self._tokenize_gene_expression(X)
|
| 364 |
|
| 365 |
-
# Add special tokens if available
|
| 366 |
special_tokens = np.zeros((token_ids.shape[0], 3), dtype=np.int64)
|
| 367 |
special_token_mask = np.zeros((token_ids.shape[0], 3), dtype=bool)
|
| 368 |
|
| 369 |
-
if modality_tokens is not None:
|
| 370 |
-
special_tokens[:, 0] = modality_tokens
|
| 371 |
-
special_token_mask[:, 0] = True
|
| 372 |
-
|
| 373 |
if species_tokens is not None:
|
| 374 |
-
special_tokens[:,
|
| 375 |
-
special_token_mask[:,
|
| 376 |
|
| 377 |
if technology_tokens is not None:
|
| 378 |
-
special_tokens[:,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
special_token_mask[:, 2] = True
|
| 380 |
|
| 381 |
# Only keep the special tokens that are present (have True in mask)
|
|
|
|
| 287 |
for col in ['modality', 'specie', 'assay']:
|
| 288 |
if col in adata.obs.columns:
|
| 289 |
original_types[col] = adata.obs[col].dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
# Concatenate and then remove the reference
|
| 292 |
adata = ad.concat([reference_model, adata], join='outer', axis=0)
|
|
|
|
| 299 |
adata.obs[col] = adata.obs[col].astype(dtype)
|
| 300 |
except Exception as e:
|
| 301 |
print(f"Warning: Could not convert {col} back to {dtype}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
# Get gene expression data
|
| 304 |
X = adata.X
|
|
|
|
| 353 |
# Tokenize gene expression data
|
| 354 |
token_ids = self._tokenize_gene_expression(X)
|
| 355 |
|
| 356 |
+
# Add special tokens if available - changed order to [species, technology, modality]
|
| 357 |
special_tokens = np.zeros((token_ids.shape[0], 3), dtype=np.int64)
|
| 358 |
special_token_mask = np.zeros((token_ids.shape[0], 3), dtype=bool)
|
| 359 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
if species_tokens is not None:
|
| 361 |
+
special_tokens[:, 0] = species_tokens
|
| 362 |
+
special_token_mask[:, 0] = True
|
| 363 |
|
| 364 |
if technology_tokens is not None:
|
| 365 |
+
special_tokens[:, 1] = technology_tokens
|
| 366 |
+
special_token_mask[:, 1] = True
|
| 367 |
+
|
| 368 |
+
if modality_tokens is not None:
|
| 369 |
+
special_tokens[:, 2] = modality_tokens
|
| 370 |
special_token_mask[:, 2] = True
|
| 371 |
|
| 372 |
# Only keep the special tokens that are present (have True in mask)
|