Update geneformer/perturber_utils.py
Browse filesIf overexpressing & cls is present, shift token to 1st position instead of 0th
geneformer/perturber_utils.py
CHANGED
@@ -218,26 +218,35 @@ def delete_indices(example):
|
|
218 |
|
219 |
|
220 |
# for genes_to_perturb = "all" where only genes within cell are overexpressed
|
221 |
-
def overexpress_indices(example):
|
222 |
indices = example["perturb_index"]
|
223 |
if any(isinstance(el, list) for el in indices):
|
224 |
indices = flatten_list(indices)
|
225 |
for index in sorted(indices, reverse=True):
|
226 |
-
|
|
|
|
|
|
|
227 |
|
228 |
example["length"] = len(example["input_ids"])
|
229 |
return example
|
230 |
|
231 |
|
232 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
233 |
-
def overexpress_tokens(example, max_len):
|
234 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
235 |
if example["perturb_index"] != [-100]:
|
236 |
example = delete_indices(example)
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
# truncate to max input size, must also truncate original emb to be comparable
|
243 |
if len(example["input_ids"]) > max_len:
|
|
|
218 |
|
219 |
|
220 |
# for genes_to_perturb = "all" where only genes within cell are overexpressed
|
221 |
+
def overexpress_indices(example, special_token):
|
222 |
indices = example["perturb_index"]
|
223 |
if any(isinstance(el, list) for el in indices):
|
224 |
indices = flatten_list(indices)
|
225 |
for index in sorted(indices, reverse=True):
|
226 |
+
if special_token:
|
227 |
+
example["input_ids"].insert(1, example["input_ids"].pop(index))
|
228 |
+
else:
|
229 |
+
example["input_ids"].insert(0, example["input_ids"].pop(index))
|
230 |
|
231 |
example["length"] = len(example["input_ids"])
|
232 |
return example
|
233 |
|
234 |
|
235 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
236 |
+
def overexpress_tokens(example, max_len, special_token):
|
237 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
238 |
if example["perturb_index"] != [-100]:
|
239 |
example = delete_indices(example)
|
240 |
+
if special_token:
|
241 |
+
[
|
242 |
+
example["input_ids"].insert(1, token)
|
243 |
+
for token in example["tokens_to_perturb"][::-1]
|
244 |
+
]
|
245 |
+
else:
|
246 |
+
[
|
247 |
+
example["input_ids"].insert(0, token)
|
248 |
+
for token in example["tokens_to_perturb"][::-1]
|
249 |
+
]
|
250 |
|
251 |
# truncate to max input size, must also truncate original emb to be comparable
|
252 |
if len(example["input_ids"]) > max_len:
|