Update geneformer/emb_extractor.py
Browse files
geneformer/emb_extractor.py
CHANGED
@@ -411,7 +411,7 @@ class EmbExtractor:
|
|
411 |
self,
|
412 |
model_type="Pretrained",
|
413 |
num_classes=0,
|
414 |
-
emb_mode="
|
415 |
cell_emb_style="mean_pool",
|
416 |
gene_emb_style="mean_pool",
|
417 |
filter_data=None,
|
@@ -596,11 +596,11 @@ class EmbExtractor:
|
|
596 |
filtered_input_data = pu.load_and_filter(
|
597 |
self.filter_data, self.nproc, input_data_file
|
598 |
)
|
599 |
-
|
600 |
# Check to make sure that all the labels exist in the tokenized data:
|
601 |
if self.emb_label is not None:
|
602 |
for label in self.emb_label:
|
603 |
-
assert label in
|
604 |
|
605 |
if cell_state is not None:
|
606 |
filtered_input_data = pu.filter_by_dict(
|
@@ -725,6 +725,12 @@ class EmbExtractor:
|
|
725 |
)
|
726 |
raise
|
727 |
|
|
|
|
|
|
|
|
|
|
|
|
|
728 |
state_embs_dict = dict()
|
729 |
state_key = cell_states_to_model["state_key"]
|
730 |
for k, v in cell_states_to_model.items():
|
|
|
411 |
self,
|
412 |
model_type="Pretrained",
|
413 |
num_classes=0,
|
414 |
+
emb_mode="cls",
|
415 |
cell_emb_style="mean_pool",
|
416 |
gene_emb_style="mean_pool",
|
417 |
filter_data=None,
|
|
|
596 |
filtered_input_data = pu.load_and_filter(
|
597 |
self.filter_data, self.nproc, input_data_file
|
598 |
)
|
599 |
+
|
600 |
# Check to make sure that all the labels exist in the tokenized data:
|
601 |
if self.emb_label is not None:
|
602 |
for label in self.emb_label:
|
603 |
+
assert label in filtered_input_data.features.keys(), f"Attribute `{label}` not present in dataset features"
|
604 |
|
605 |
if cell_state is not None:
|
606 |
filtered_input_data = pu.filter_by_dict(
|
|
|
725 |
)
|
726 |
raise
|
727 |
|
728 |
+
if self.emb_label is not None:
|
729 |
+
logger.error(
|
730 |
+
"For extracting state embs, emb_label should be None since labels are based on state embs dict keys."
|
731 |
+
)
|
732 |
+
raise
|
733 |
+
|
734 |
state_embs_dict = dict()
|
735 |
state_key = cell_states_to_model["state_key"]
|
736 |
for k, v in cell_states_to_model.items():
|