aletlvl committed on
Commit
8ac9ba0
·
verified ·
1 Parent(s): c54fbd2

Tokenization fixed

Browse files
Files changed (1) hide show
  1. tokenization_nicheformer.py +8 -17
tokenization_nicheformer.py CHANGED
@@ -287,11 +287,6 @@ class NicheformerTokenizer(PreTrainedTokenizer):
287
  for col in ['modality', 'specie', 'assay']:
288
  if col in adata.obs.columns:
289
  original_types[col] = adata.obs[col].dtype
290
-
291
-
292
- print(f"modality dtype: {adata.obs['modality'].dtype}")
293
- print(f"specie dtype: {adata.obs['specie'].dtype}")
294
- print(f"assay dtype: {adata.obs['assay'].dtype}")
295
 
296
  # Concatenate and then remove the reference
297
  adata = ad.concat([reference_model, adata], join='outer', axis=0)
@@ -304,10 +299,6 @@ class NicheformerTokenizer(PreTrainedTokenizer):
304
  adata.obs[col] = adata.obs[col].astype(dtype)
305
  except Exception as e:
306
  print(f"Warning: Could not convert {col} back to {dtype}: {e}")
307
-
308
- print(f"modality dtype: {adata.obs['modality'].dtype}")
309
- print(f"specie dtype: {adata.obs['specie'].dtype}")
310
- print(f"assay dtype: {adata.obs['assay'].dtype}")
311
 
312
  # Get gene expression data
313
  X = adata.X
@@ -362,20 +353,20 @@ class NicheformerTokenizer(PreTrainedTokenizer):
362
  # Tokenize gene expression data
363
  token_ids = self._tokenize_gene_expression(X)
364
 
365
- # Add special tokens if available
366
  special_tokens = np.zeros((token_ids.shape[0], 3), dtype=np.int64)
367
  special_token_mask = np.zeros((token_ids.shape[0], 3), dtype=bool)
368
 
369
- if modality_tokens is not None:
370
- special_tokens[:, 0] = modality_tokens
371
- special_token_mask[:, 0] = True
372
-
373
  if species_tokens is not None:
374
- special_tokens[:, 1] = species_tokens
375
- special_token_mask[:, 1] = True
376
 
377
  if technology_tokens is not None:
378
- special_tokens[:, 2] = technology_tokens
 
 
 
 
379
  special_token_mask[:, 2] = True
380
 
381
  # Only keep the special tokens that are present (have True in mask)
 
287
  for col in ['modality', 'specie', 'assay']:
288
  if col in adata.obs.columns:
289
  original_types[col] = adata.obs[col].dtype
 
 
 
 
 
290
 
291
  # Concatenate and then remove the reference
292
  adata = ad.concat([reference_model, adata], join='outer', axis=0)
 
299
  adata.obs[col] = adata.obs[col].astype(dtype)
300
  except Exception as e:
301
  print(f"Warning: Could not convert {col} back to {dtype}: {e}")
 
 
 
 
302
 
303
  # Get gene expression data
304
  X = adata.X
 
353
  # Tokenize gene expression data
354
  token_ids = self._tokenize_gene_expression(X)
355
 
356
+ # Add special tokens if available - changed order to [species, technology, modality]
357
  special_tokens = np.zeros((token_ids.shape[0], 3), dtype=np.int64)
358
  special_token_mask = np.zeros((token_ids.shape[0], 3), dtype=bool)
359
 
 
 
 
 
360
  if species_tokens is not None:
361
+ special_tokens[:, 0] = species_tokens
362
+ special_token_mask[:, 0] = True
363
 
364
  if technology_tokens is not None:
365
+ special_tokens[:, 1] = technology_tokens
366
+ special_token_mask[:, 1] = True
367
+
368
+ if modality_tokens is not None:
369
+ special_tokens[:, 2] = modality_tokens
370
  special_token_mask[:, 2] = True
371
 
372
  # Only keep the special tokens that are present (have True in mask)