Update geneformer/tokenizer.py

#450
by hchen725 - opened
Files changed (1) hide show
  1. geneformer/tokenizer.py +42 -36
geneformer/tokenizer.py CHANGED
@@ -103,33 +103,38 @@ def sum_ensembl_ids(
103
  assert (
104
  "ensembl_id_collapsed" not in data.ra.keys()
105
  ), "'ensembl_id_collapsed' column already exists in data.ra.keys()"
 
 
 
 
106
  # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
107
  # Comparing to gene_token_dict here, would not perform any mapping steps
108
- gene_ids_in_dict = [
109
- gene for gene in data.ra.ensembl_id if gene in gene_token_dict.keys()
110
- ]
111
- if collapse_gene_ids is False:
112
-
113
- if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
114
  return data_directory
115
  else:
116
  raise ValueError("Error: data Ensembl IDs non-unique.")
117
-
118
- gene_ids_collapsed = [
119
- gene_mapping_dict.get(gene_id.upper()) for gene_id in data.ra.ensembl_id
120
- ]
121
- gene_ids_collapsed_in_dict = [
122
- gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
123
- ]
124
-
125
- if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
126
- data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
127
  return data_directory
 
128
  else:
129
  dedup_filename = data_directory.with_name(
130
  data_directory.stem + "__dedup.loom"
131
  )
132
- data.ra["ensembl_id_collapsed"] = gene_ids_collapsed
 
133
  dup_genes = [
134
  idx
135
  for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
@@ -204,32 +209,33 @@ def sum_ensembl_ids(
204
  "ensembl_id_collapsed" not in data.var.columns
205
  ), "'ensembl_id_collapsed' column already exists in data.var"
206
 
 
 
 
207
  # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
208
  # Comparing to gene_token_dict here, would not perform any mapping steps
209
- gene_ids_in_dict = [
210
- gene for gene in data.var.ensembl_id if gene in gene_token_dict.keys()
211
- ]
212
- if collapse_gene_ids is False:
213
-
214
- if len(gene_ids_in_dict) == len(set(gene_ids_in_dict)):
215
- return data
216
  else:
217
  raise ValueError("Error: data Ensembl IDs non-unique.")
218
 
219
- # Check for when collapse_gene_ids is True
220
- gene_ids_collapsed = [
221
- gene_mapping_dict.get(gene_id.upper()) for gene_id in data.var.ensembl_id
222
- ]
223
- gene_ids_collapsed_in_dict = [
224
- gene for gene in gene_ids_collapsed if gene in gene_token_dict.keys()
225
- ]
226
- if len(set(gene_ids_in_dict)) == len(set(gene_ids_collapsed_in_dict)):
227
- data.var["ensembl_id_collapsed"] = data.var.ensembl_id.map(gene_mapping_dict)
228
- return data
229
 
 
 
 
 
 
 
230
  else:
231
- data.var["ensembl_id_collapsed"] = gene_ids_collapsed
232
- data.var_names = gene_ids_collapsed
233
  data = data[:, ~data.var.index.isna()]
234
  dup_genes = [
235
  idx for idx, count in Counter(data.var_names).items() if count > 1
 
103
  assert (
104
  "ensembl_id_collapsed" not in data.ra.keys()
105
  ), "'ensembl_id_collapsed' column already exists in data.ra.keys()"
106
+
107
+
108
+ # Get the ensembl ids that exist in data
109
+ ensembl_ids = data.ra.ensembl_id
110
  # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
111
  # Comparing to gene_token_dict here, would not perform any mapping steps
112
+ if not collapse_gene_ids:
113
+ ensembl_id_check = [
114
+ gene for gene in ensembl_ids if gene in gene_token_dict.keys()
115
+ ]
116
+ if len(ensembl_id_check) == len(set(ensembl_id_check)):
 
117
  return data_directory
118
  else:
119
  raise ValueError("Error: data Ensembl IDs non-unique.")
120
+
121
+ # Get the genes that exist in the mapping dictionary and the value of those genes
122
+ genes_in_map_dict = [gene for gene in ensembl_ids if gene in gene_mapping_dict.keys()]
123
+ vals_from_map_dict = [gene_mapping_dict.get(gene) for gene in genes_in_map_dict]
124
+
125
+ # If the number of unique genes found in the mapping dict equals the number of unique values they map to,
126
+ # no collapsing is needed: simply store the mapped values and return
127
+ if(len(set(genes_in_map_dict)) == len(set(vals_from_map_dict))):
128
+ mapped_vals = [gene_mapping_dict.get(gene.upper()) for gene in data.ra["ensembl_id"]]
129
+ data.ra["ensembl_id_collapsed"] = mapped_vals
130
  return data_directory
131
+ # Genes need to be collapsed
132
  else:
133
  dedup_filename = data_directory.with_name(
134
  data_directory.stem + "__dedup.loom"
135
  )
136
+ mapped_vals = [gene_mapping_dict.get(gene.upper()) for gene in data.ra["ensembl_id"]]
137
+ data.ra["ensembl_id_collapsed"] = mapped_vals
138
  dup_genes = [
139
  idx
140
  for idx, count in Counter(data.ra["ensembl_id_collapsed"]).items()
 
209
  "ensembl_id_collapsed" not in data.var.columns
210
  ), "'ensembl_id_collapsed' column already exists in data.var"
211
 
212
+
213
+ # Get the ensembl ids that exist in data
214
+ ensembl_ids = data.var.ensembl_id
215
  # Check for duplicate Ensembl IDs if collapse_gene_ids is False.
216
  # Comparing to gene_token_dict here, would not perform any mapping steps
217
+ if not collapse_gene_ids:
218
+ ensembl_id_check = [
219
+ gene for gene in ensembl_ids if gene in gene_token_dict.keys()
220
+ ]
221
+ if len(ensembl_id_check) == len(set(ensembl_id_check)):
222
+ return data_directory
 
223
  else:
224
  raise ValueError("Error: data Ensembl IDs non-unique.")
225
 
226
+ # Get the genes that exist in the mapping dictionary and the value of those genes
227
+ genes_in_map_dict = [gene for gene in ensembl_ids if gene in gene_mapping_dict.keys()]
228
+ vals_from_map_dict = [gene_mapping_dict.get(gene) for gene in genes_in_map_dict]
 
 
 
 
 
 
 
229
 
230
+ # If the number of unique genes found in the mapping dict equals the number of unique values they map to,
231
+ # no collapsing is needed: simply store the mapped values and return
232
+ if(len(set(genes_in_map_dict)) == len(set(vals_from_map_dict))):
233
+ data.var["ensembl_id_collapsed"] = data.var.ensembl_id.str.upper().map(gene_mapping_dict)
234
+ return data
235
+ # Genes need to be collapsed
236
  else:
237
+ data.var["ensembl_id_collapsed"] = data.var.ensembl_id.str.upper().map(gene_mapping_dict)
238
+ data.var_names = data.var["ensembl_id_collapsed"]
239
  data = data[:, ~data.var.index.isna()]
240
  dup_genes = [
241
  idx for idx, count in Counter(data.var_names).items() if count > 1