aletlvl commited on
Commit
a27a63d
·
verified ·
1 Parent(s): ac2a80e

Tokenization fixed

Browse files
Files changed (1) hide show
  1. tokenization_nicheformer.py +11 -2
tokenization_nicheformer.py CHANGED
@@ -277,7 +277,7 @@ class NicheformerTokenizer(PreTrainedTokenizer):
277
  """
278
  if isinstance(data, ad.AnnData):
279
  adata = data.copy()
280
-
281
  # Align with reference model if available
282
  if hasattr(self, '_load_reference_model'):
283
  reference_model = self._load_reference_model()
@@ -287,11 +287,16 @@ class NicheformerTokenizer(PreTrainedTokenizer):
287
  for col in ['modality', 'specie', 'assay']:
288
  if col in adata.obs.columns:
289
  original_types[col] = adata.obs[col].dtype
 
 
 
 
 
290
 
291
  # Concatenate and then remove the reference
292
  adata = ad.concat([reference_model, adata], join='outer', axis=0)
293
  adata = adata[1:]
294
-
295
  # Restore original column types after concatenation
296
  for col, dtype in original_types.items():
297
  if col in adata.obs.columns:
@@ -299,6 +304,10 @@ class NicheformerTokenizer(PreTrainedTokenizer):
299
  adata.obs[col] = adata.obs[col].astype(dtype)
300
  except Exception as e:
301
  print(f"Warning: Could not convert {col} back to {dtype}: {e}")
 
 
 
 
302
 
303
  # Get gene expression data
304
  X = adata.X
 
277
  """
278
  if isinstance(data, ad.AnnData):
279
  adata = data.copy()
280
+
281
  # Align with reference model if available
282
  if hasattr(self, '_load_reference_model'):
283
  reference_model = self._load_reference_model()
 
287
  for col in ['modality', 'specie', 'assay']:
288
  if col in adata.obs.columns:
289
  original_types[col] = adata.obs[col].dtype
290
+
291
+
292
+ print(f"modality dtype: {adata.obs['modality'].dtype}")
293
+ print(f"specie dtype: {adata.obs['specie'].dtype}")
294
+ print(f"assay dtype: {adata.obs['assay'].dtype}")
295
 
296
  # Concatenate and then remove the reference
297
  adata = ad.concat([reference_model, adata], join='outer', axis=0)
298
  adata = adata[1:]
299
+
300
  # Restore original column types after concatenation
301
  for col, dtype in original_types.items():
302
  if col in adata.obs.columns:
 
304
  adata.obs[col] = adata.obs[col].astype(dtype)
305
  except Exception as e:
306
  print(f"Warning: Could not convert {col} back to {dtype}: {e}")
307
+
308
+ print(f"modality dtype: {adata.obs['modality'].dtype}")
309
+ print(f"specie dtype: {adata.obs['specie'].dtype}")
310
+ print(f"assay dtype: {adata.obs['assay'].dtype}")
311
 
312
  # Get gene expression data
313
  X = adata.X