hchen725 committed on
Commit
6207453
·
verified ·
1 Parent(s): 2c8d3f5

Update geneformer/emb_extractor.py

Browse files
Files changed (1) hide show
  1. geneformer/emb_extractor.py +9 -3
geneformer/emb_extractor.py CHANGED
@@ -411,7 +411,7 @@ class EmbExtractor:
411
  self,
412
  model_type="Pretrained",
413
  num_classes=0,
414
- emb_mode="cell",
415
  cell_emb_style="mean_pool",
416
  gene_emb_style="mean_pool",
417
  filter_data=None,
@@ -596,11 +596,11 @@ class EmbExtractor:
596
  filtered_input_data = pu.load_and_filter(
597
  self.filter_data, self.nproc, input_data_file
598
  )
599
-
600
  # Check to make sure that all the labels exist in the tokenized data:
601
  if self.emb_label is not None:
602
  for label in self.emb_label:
603
- assert label in list(filtered_input_data.features), f"Attribute `{label}` not present in dataset features"
604
 
605
  if cell_state is not None:
606
  filtered_input_data = pu.filter_by_dict(
@@ -725,6 +725,12 @@ class EmbExtractor:
725
  )
726
  raise
727
 
 
 
 
 
 
 
728
  state_embs_dict = dict()
729
  state_key = cell_states_to_model["state_key"]
730
  for k, v in cell_states_to_model.items():
 
411
  self,
412
  model_type="Pretrained",
413
  num_classes=0,
414
+ emb_mode="cls",
415
  cell_emb_style="mean_pool",
416
  gene_emb_style="mean_pool",
417
  filter_data=None,
 
596
  filtered_input_data = pu.load_and_filter(
597
  self.filter_data, self.nproc, input_data_file
598
  )
599
+
600
  # Check to make sure that all the labels exist in the tokenized data:
601
  if self.emb_label is not None:
602
  for label in self.emb_label:
603
+ assert label in filtered_input_data.features.keys(), f"Attribute `{label}` not present in dataset features"
604
 
605
  if cell_state is not None:
606
  filtered_input_data = pu.filter_by_dict(
 
725
  )
726
  raise
727
 
728
+ if self.emb_label is not None:
729
+ logger.error(
730
+ "For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
731
+ )
732
+ raise
733
+
734
  state_embs_dict = dict()
735
  state_key = cell_states_to_model["state_key"]
736
  for k, v in cell_states_to_model.items():