aletlvl committed on
Commit
ac2a80e
·
verified ·
1 Parent(s): e23f6f2

Tokenization fixed

Browse files
Files changed (1) hide show
  1. tokenization_nicheformer.py +45 -33
tokenization_nicheformer.py CHANGED
@@ -277,61 +277,73 @@ class NicheformerTokenizer(PreTrainedTokenizer):
277
  """
278
  if isinstance(data, ad.AnnData):
279
  adata = data.copy()
280
- print("READING")
281
- print(f"modality dtype: {adata.obs['modality'].dtype}")
282
- print(f"specie dtype: {adata.obs['specie'].dtype}")
283
- print(f"assay dtype: {adata.obs['assay'].dtype}")
284
-
285
  # Align with reference model if available
286
  if hasattr(self, '_load_reference_model'):
287
  reference_model = self._load_reference_model()
288
  if reference_model is not None:
 
 
 
 
 
 
289
  # Concatenate and then remove the reference
290
  adata = ad.concat([reference_model, adata], join='outer', axis=0)
291
  adata = adata[1:]
 
 
 
 
 
 
 
 
292
 
293
- print("AFTER CONCATENATION")
294
- print(f"modality dtype: {adata.obs['modality'].dtype}")
295
- print(f"specie dtype: {adata.obs['specie'].dtype}")
296
- print(f"assay dtype: {adata.obs['assay'].dtype}")
297
  # Get gene expression data
298
  X = adata.X
299
 
300
  # Get metadata for special tokens
301
- # Print column types
302
- print("\nColumn types:")
303
- if 'modality' in adata.obs.columns:
304
- print(f"modality type: {type(adata.obs['modality'])} with dtype: {adata.obs['modality'].dtype}")
305
- if 'specie' in adata.obs.columns:
306
- print(f"specie type: {type(adata.obs['specie'])} with dtype: {adata.obs['specie'].dtype}")
307
- if 'assay' in adata.obs.columns:
308
- print(f"assay type: {type(adata.obs['assay'])} with dtype: {adata.obs['assay'].dtype}")
309
  modality = adata.obs['modality'] if 'modality' in adata.obs.columns else None
310
  species = adata.obs['specie'] if 'specie' in adata.obs.columns else None
311
  technology = adata.obs['assay'] if 'assay' in adata.obs.columns else None
312
 
313
- print(f"Modality: {modality}")
314
- print(f"Species: {species}")
315
- print(f"Technology: {technology}")
316
  # Use integer values directly if available
317
- if modality is not None and pd.api.types.is_numeric_dtype(modality):
318
- modality_tokens = modality.astype(int).tolist()
 
 
 
 
 
 
 
319
  else:
320
- modality_tokens = [self.modality_dict.get(m, self._vocabulary["[PAD]"]) for m in modality] if modality is not None else None
321
 
322
- if species is not None and pd.api.types.is_numeric_dtype(species):
323
- species_tokens = species.astype(int).tolist()
324
- print(f"Species tokens: {species_tokens}")
 
 
 
 
 
 
325
  else:
326
- species_tokens = [self.species_dict.get(s, self._vocabulary["[PAD]"]) for s in species] if species is not None else None
327
- print(f"Species tokens resort: {species_tokens}")
328
 
329
- if technology is not None and pd.api.types.is_numeric_dtype(technology):
330
- technology_tokens = technology.astype(int).tolist()
331
- print(f"Technology tokens: {technology_tokens}")
 
 
 
 
 
 
332
  else:
333
- technology_tokens = [self.technology_dict.get(t, self._vocabulary["[PAD]"]) for t in technology] if technology is not None else None
334
- print(f"Technology tokens resort: {technology_tokens}")
335
  else:
336
  X = data
337
  modality_tokens = None
 
277
  """
278
  if isinstance(data, ad.AnnData):
279
  adata = data.copy()
280
+
 
 
 
 
281
  # Align with reference model if available
282
  if hasattr(self, '_load_reference_model'):
283
  reference_model = self._load_reference_model()
284
  if reference_model is not None:
285
+ # Store original column types before concatenation
286
+ original_types = {}
287
+ for col in ['modality', 'specie', 'assay']:
288
+ if col in adata.obs.columns:
289
+ original_types[col] = adata.obs[col].dtype
290
+
291
  # Concatenate and then remove the reference
292
  adata = ad.concat([reference_model, adata], join='outer', axis=0)
293
  adata = adata[1:]
294
+
295
+ # Restore original column types after concatenation
296
+ for col, dtype in original_types.items():
297
+ if col in adata.obs.columns:
298
+ try:
299
+ adata.obs[col] = adata.obs[col].astype(dtype)
300
+ except Exception as e:
301
+ print(f"Warning: Could not convert {col} back to {dtype}: {e}")
302
 
 
 
 
 
303
  # Get gene expression data
304
  X = adata.X
305
 
306
  # Get metadata for special tokens
 
 
 
 
 
 
 
 
307
  modality = adata.obs['modality'] if 'modality' in adata.obs.columns else None
308
  species = adata.obs['specie'] if 'specie' in adata.obs.columns else None
309
  technology = adata.obs['assay'] if 'assay' in adata.obs.columns else None
310
 
 
 
 
311
  # Use integer values directly if available
312
+ if modality is not None:
313
+ try:
314
+ if pd.api.types.is_numeric_dtype(modality):
315
+ modality_tokens = modality.astype(int).tolist()
316
+ else:
317
+ modality_tokens = [self.modality_dict.get(m, self._vocabulary["[PAD]"]) for m in modality]
318
+ except Exception as e:
319
+ print(f"Warning: Error processing modality tokens: {e}")
320
+ modality_tokens = [self._vocabulary["[PAD]"]] * len(adata)
321
  else:
322
+ modality_tokens = None
323
 
324
+ if species is not None:
325
+ try:
326
+ if pd.api.types.is_numeric_dtype(species):
327
+ species_tokens = species.astype(int).tolist()
328
+ else:
329
+ species_tokens = [self.species_dict.get(s, self._vocabulary["[PAD]"]) for s in species]
330
+ except Exception as e:
331
+ print(f"Warning: Error processing species tokens: {e}")
332
+ species_tokens = [self._vocabulary["[PAD]"]] * len(adata)
333
  else:
334
+ species_tokens = None
 
335
 
336
+ if technology is not None:
337
+ try:
338
+ if pd.api.types.is_numeric_dtype(technology):
339
+ technology_tokens = technology.astype(int).tolist()
340
+ else:
341
+ technology_tokens = [self.technology_dict.get(t, self._vocabulary["[PAD]"]) for t in technology]
342
+ except Exception as e:
343
+ print(f"Warning: Error processing technology tokens: {e}")
344
+ technology_tokens = [self._vocabulary["[PAD]"]] * len(adata)
345
  else:
346
+ technology_tokens = None
 
347
  else:
348
  X = data
349
  modality_tokens = None