ctheodoris
committed on
Commit
•
916546e
1
Parent(s):
f0ec9ca
adjust logging
Browse files- geneformer/emb_extractor.py +15 -17
geneformer/emb_extractor.py
CHANGED
@@ -75,19 +75,18 @@ def get_embs(
|
|
75 |
if emb_mode == "cls":
|
76 |
assert cls_present, "<cls> token missing in token dictionary"
|
77 |
# Check to make sure that the first token of the filtered input data is cls token
|
78 |
-
for
|
79 |
-
|
80 |
-
cls_token_id = key
|
81 |
assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
|
82 |
else:
|
83 |
if cls_present:
|
84 |
-
logger.warning("CLS token present in token dictionary, excluding from average")
|
85 |
if eos_present:
|
86 |
-
logger.warning("EOS token present in token dictionary, excluding from average")
|
87 |
|
88 |
overall_max_len = 0
|
89 |
|
90 |
-
for i in trange(0, total_batch_length, forward_batch_size, leave
|
91 |
max_range = min(i + forward_batch_size, total_batch_length)
|
92 |
|
93 |
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
@@ -163,7 +162,7 @@ def get_embs(
|
|
163 |
|
164 |
|
165 |
if summary_stat is None:
|
166 |
-
if emb_mode == "cell":
|
167 |
embs_stack = torch.cat(embs_list, dim=0)
|
168 |
elif emb_mode == "gene":
|
169 |
embs_stack = pu.pad_tensor_list(
|
@@ -174,8 +173,6 @@ def get_embs(
|
|
174 |
1,
|
175 |
pu.pad_3d_tensor,
|
176 |
)
|
177 |
-
elif emb_mode == "cls":
|
178 |
-
embs_stack = torch.cat(embs_list, dim=0)
|
179 |
|
180 |
# calculate summary stat embs from approximated tdigests
|
181 |
elif summary_stat is not None:
|
@@ -382,7 +379,7 @@ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
|
382 |
bbox_to_anchor=(0.5, 1),
|
383 |
facecolor="white",
|
384 |
)
|
385 |
-
|
386 |
plt.savefig(output_file, bbox_inches="tight")
|
387 |
|
388 |
|
@@ -390,7 +387,7 @@ class EmbExtractor:
|
|
390 |
valid_option_dict = {
|
391 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
392 |
"num_classes": {int},
|
393 |
-
"emb_mode": {"
|
394 |
"cell_emb_style": {"mean_pool"},
|
395 |
"gene_emb_style": {"mean_pool"},
|
396 |
"filter_data": {None, dict},
|
@@ -431,10 +428,11 @@ class EmbExtractor:
|
|
431 |
num_classes : int
|
432 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
433 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
434 |
-
emb_mode : {"cell", "gene"}
|
435 |
-
| Whether to output cell or gene embeddings.
|
436 |
-
|
437 |
-
|
|
|
438 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
439 |
gene_emb_style : "mean_pool"
|
440 |
| Method for summarizing gene embeddings.
|
@@ -469,7 +467,7 @@ class EmbExtractor:
|
|
469 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
470 |
| Non-exact is slower but more memory-efficient.
|
471 |
token_dictionary_file : Path
|
472 |
-
| Default is
|
473 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
474 |
|
475 |
**Examples:**
|
@@ -841,4 +839,4 @@ class EmbExtractor:
|
|
841 |
output_file = (
|
842 |
Path(output_directory) / output_prefix_label
|
843 |
).with_suffix(".pdf")
|
844 |
-
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|
|
|
75 |
if emb_mode == "cls":
|
76 |
assert cls_present, "<cls> token missing in token dictionary"
|
77 |
# Check to make sure that the first token of the filtered input data is cls token
|
78 |
+
gene_token_dict = {v: k for k, v in token_gene_dict.items()}
|
79 |
+
cls_token_id = gene_token_dict["<cls>"]
|
|
|
80 |
assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
|
81 |
else:
|
82 |
if cls_present:
|
83 |
+
logger.warning("CLS token present in token dictionary, excluding from average.")
|
84 |
if eos_present:
|
85 |
+
logger.warning("EOS token present in token dictionary, excluding from average.")
|
86 |
|
87 |
overall_max_len = 0
|
88 |
|
89 |
+
for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
|
90 |
max_range = min(i + forward_batch_size, total_batch_length)
|
91 |
|
92 |
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
|
|
162 |
|
163 |
|
164 |
if summary_stat is None:
|
165 |
+
if (emb_mode == "cell") or (emb_mode == "cls"):
|
166 |
embs_stack = torch.cat(embs_list, dim=0)
|
167 |
elif emb_mode == "gene":
|
168 |
embs_stack = pu.pad_tensor_list(
|
|
|
173 |
1,
|
174 |
pu.pad_3d_tensor,
|
175 |
)
|
|
|
|
|
176 |
|
177 |
# calculate summary stat embs from approximated tdigests
|
178 |
elif summary_stat is not None:
|
|
|
379 |
bbox_to_anchor=(0.5, 1),
|
380 |
facecolor="white",
|
381 |
)
|
382 |
+
logger.info(f"Output file: {output_file}")
|
383 |
plt.savefig(output_file, bbox_inches="tight")
|
384 |
|
385 |
|
|
|
387 |
valid_option_dict = {
|
388 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
389 |
"num_classes": {int},
|
390 |
+
"emb_mode": {"cls", "cell", "gene"},
|
391 |
"cell_emb_style": {"mean_pool"},
|
392 |
"gene_emb_style": {"mean_pool"},
|
393 |
"filter_data": {None, dict},
|
|
|
428 |
num_classes : int
|
429 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
430 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
431 |
+
emb_mode : {"cls", "cell", "gene"}
|
432 |
+
| Whether to output CLS, cell, or gene embeddings.
|
433 |
+
| CLS embeddings are cell embeddings derived from the CLS token in the front of the rank value encoding.
|
434 |
+
cell_emb_style : {"mean_pool"}
|
435 |
+
| Method for summarizing cell embeddings if not using CLS token.
|
436 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
437 |
gene_emb_style : "mean_pool"
|
438 |
| Method for summarizing gene embeddings.
|
|
|
467 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
468 |
| Non-exact is slower but more memory-efficient.
|
469 |
token_dictionary_file : Path
|
470 |
+
| Default is the Geneformer token dictionary
|
471 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
472 |
|
473 |
**Examples:**
|
|
|
839 |
output_file = (
|
840 |
Path(output_directory) / output_prefix_label
|
841 |
).with_suffix(".pdf")
|
842 |
+
plot_heatmap(embs, emb_dims, label, output_file, kwargs_dict)
|