Tokenization fixed
Browse files- tokenization_nicheformer.py +11 -2
tokenization_nicheformer.py
CHANGED
|
@@ -277,7 +277,7 @@ class NicheformerTokenizer(PreTrainedTokenizer):
|
|
| 277 |
"""
|
| 278 |
if isinstance(data, ad.AnnData):
|
| 279 |
adata = data.copy()
|
| 280 |
-
|
| 281 |
# Align with reference model if available
|
| 282 |
if hasattr(self, '_load_reference_model'):
|
| 283 |
reference_model = self._load_reference_model()
|
|
@@ -287,11 +287,16 @@ class NicheformerTokenizer(PreTrainedTokenizer):
|
|
| 287 |
for col in ['modality', 'specie', 'assay']:
|
| 288 |
if col in adata.obs.columns:
|
| 289 |
original_types[col] = adata.obs[col].dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
# Concatenate and then remove the reference
|
| 292 |
adata = ad.concat([reference_model, adata], join='outer', axis=0)
|
| 293 |
adata = adata[1:]
|
| 294 |
-
|
| 295 |
# Restore original column types after concatenation
|
| 296 |
for col, dtype in original_types.items():
|
| 297 |
if col in adata.obs.columns:
|
|
@@ -299,6 +304,10 @@ class NicheformerTokenizer(PreTrainedTokenizer):
|
|
| 299 |
adata.obs[col] = adata.obs[col].astype(dtype)
|
| 300 |
except Exception as e:
|
| 301 |
print(f"Warning: Could not convert {col} back to {dtype}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
|
| 303 |
# Get gene expression data
|
| 304 |
X = adata.X
|
|
|
|
| 277 |
"""
|
| 278 |
if isinstance(data, ad.AnnData):
|
| 279 |
adata = data.copy()
|
| 280 |
+
|
| 281 |
# Align with reference model if available
|
| 282 |
if hasattr(self, '_load_reference_model'):
|
| 283 |
reference_model = self._load_reference_model()
|
|
|
|
| 287 |
for col in ['modality', 'specie', 'assay']:
|
| 288 |
if col in adata.obs.columns:
|
| 289 |
original_types[col] = adata.obs[col].dtype
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
print(f"modality dtype: {adata.obs['modality'].dtype}")
|
| 293 |
+
print(f"specie dtype: {adata.obs['specie'].dtype}")
|
| 294 |
+
print(f"assay dtype: {adata.obs['assay'].dtype}")
|
| 295 |
|
| 296 |
# Concatenate and then remove the reference
|
| 297 |
adata = ad.concat([reference_model, adata], join='outer', axis=0)
|
| 298 |
adata = adata[1:]
|
| 299 |
+
|
| 300 |
# Restore original column types after concatenation
|
| 301 |
for col, dtype in original_types.items():
|
| 302 |
if col in adata.obs.columns:
|
|
|
|
| 304 |
adata.obs[col] = adata.obs[col].astype(dtype)
|
| 305 |
except Exception as e:
|
| 306 |
print(f"Warning: Could not convert {col} back to {dtype}: {e}")
|
| 307 |
+
|
| 308 |
+
print(f"modality dtype: {adata.obs['modality'].dtype}")
|
| 309 |
+
print(f"specie dtype: {adata.obs['specie'].dtype}")
|
| 310 |
+
print(f"assay dtype: {adata.obs['assay'].dtype}")
|
| 311 |
|
| 312 |
# Get gene expression data
|
| 313 |
X = adata.X
|