Add option for variable input_size and to add CLS/SEP Tokens

#299
Files changed (1)
  1. geneformer/tokenizer.py +22 -8
geneformer/tokenizer.py CHANGED

@@ -81,14 +81,14 @@ class TranscriptomeTokenizer:
         custom_attr_name_dict=None,
         nproc=1,
         chunk_size=512,
+        input_size=2048,
+        special_token=False,
         gene_median_file=GENE_MEDIAN_FILE,
         token_dictionary_file=TOKEN_DICTIONARY_FILE,
     ):
         """
         Initialize tokenizer.
-
         **Parameters:**
-
         custom_attr_name_dict : None, dict
             | Dictionary of custom attributes to be added to the dataset.
             | Keys are the names of the attributes in the loom file.
@@ -97,6 +97,10 @@ class TranscriptomeTokenizer:
             | Number of processes to use for dataset mapping.
         chunk_size: int = 512
             | Chunk size for anndata tokenizer.
+        input_size: int = 2048
+            | Input size for tokenization
+        special_token: bool = False
+            | Option to add CLS and SEP tokens
         gene_median_file : Path
             | Path to pickle file containing dictionary of non-zero median
             | gene expression values across Genecorpus-30M.
@@ -112,6 +116,12 @@ class TranscriptomeTokenizer:
         # chunk size for anndata tokenizer
         self.chunk_size = chunk_size

+        # input size for tokenization
+        self.input_size = input_size
+
+        # add CLS and SEP tokens
+        self.special_token = special_token
+
         # load dictionary of gene normalization factors
         # (non-zero median value of expression across Genecorpus-30M)
         with open(gene_median_file, "rb") as f:
@@ -137,9 +147,7 @@ class TranscriptomeTokenizer:
     ):
         """
         Tokenize .loom files in data_directory and save as tokenized .dataset in output_directory.
-
         **Parameters:**
-
         data_directory : Path
             | Path to directory containing loom files or anndata files
         output_directory : Path
@@ -324,7 +332,7 @@ class TranscriptomeTokenizer:
                         file_cell_metadata[k] += subview.ca[k].tolist()
                 else:
                     file_cell_metadata = None
-
+
         return tokenized_cells, file_cell_metadata

     def create_dataset(
@@ -357,8 +365,14 @@ class TranscriptomeTokenizer:
             example["input_ids_uncropped"] = example["input_ids"]
             example["length_uncropped"] = len(example["input_ids"])

-            # Truncate/Crop input_ids to size 2,048
-            example["input_ids"] = example["input_ids"][0:2048]
+            # Truncate/Crop input_ids to input size
+            if tk.special_token:
+                example["input_ids"] = example["input_ids"][0:self.input_size-2]  # truncate to leave space for CLS and SEP token
+                example["input_ids"] = np.insert(example["input_ids"], 0, self.gene_token_dict.get("<cls>"))
+                example["input_ids"] = np.insert(example["input_ids"], len(example["input_ids"]), self.gene_token_dict.get("<sep>"))
+            else:
+                # Truncate/Crop input_ids to input size
+                example["input_ids"] = example["input_ids"][0:self.input_size]
             example["length"] = len(example["input_ids"])

             return example
@@ -366,4 +380,4 @@ class TranscriptomeTokenizer:
         output_dataset_truncated = output_dataset.map(
             format_cell_features, num_proc=self.nproc
         )
-        return output_dataset_truncated
+        return output_dataset_truncated
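
For context, a minimal usage sketch of the two new options. Only the constructor keywords come from the diff above; the import path and the tokenize_data(data_directory, output_directory, output_prefix) call are assumptions based on the surrounding Geneformer repo, not part of this change.

    from geneformer import TranscriptomeTokenizer

    # input_size sets the maximum number of ranked gene tokens kept per cell
    # (previously hard-coded to 2048); special_token=True reserves two of those
    # positions and wraps each cell's encoding in <cls> ... <sep>.
    tk = TranscriptomeTokenizer(
        custom_attr_name_dict={"cell_type": "cell_type"},
        nproc=4,
        chunk_size=512,
        input_size=4096,
        special_token=True,
    )

    # Hypothetical call; directories correspond to the tokenize_data docstring above.
    tk.tokenize_data("loom_data/", "tokenized/", "my_dataset")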
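The core behavioral change is the cropping rule inside format_cell_features. Below is a self-contained sketch of that rule under stated assumptions: a toy gene_token_dict stands in for the pickled token_dictionary_file, and the helper name crop_input_ids is hypothetical; in the PR the same logic runs per example inside the dataset .map() call.

    import numpy as np

    # Toy stand-in for the pickled token dictionary; only the special-token
    # ids matter for this sketch (the real ids come from token_dictionary_file).
    gene_token_dict = {"<pad>": 0, "<mask>": 1, "<cls>": 2, "<sep>": 3}

    def crop_input_ids(input_ids, input_size=2048, special_token=False):
        # Reproduce the PR's cropping rule for one cell's rank-value encoding.
        input_ids = np.asarray(input_ids)
        if special_token:
            # keep input_size - 2 gene tokens so <cls> and <sep> still fit
            input_ids = input_ids[: input_size - 2]
            input_ids = np.insert(input_ids, 0, gene_token_dict["<cls>"])
            input_ids = np.insert(input_ids, len(input_ids), gene_token_dict["<sep>"])
        else:
            # previous behavior, with the 2048 cutoff now configurable
            input_ids = input_ids[:input_size]
        return input_ids

    # Five ranked gene tokens cropped to input_size=4 with special tokens:
    print(crop_input_ids([10, 11, 12, 13, 14], input_size=4, special_token=True))
    # -> [ 2 10 11  3]

Reserving two positions keeps the final sequence at or below input_size, so adding <cls>/<sep> never pushes a cell past the model's maximum context length.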