Christina Theodoris committed on
Commit
67f674c
1 Parent(s): d20ad0a

Add uniform max len for padding for predictions

Browse files
Files changed (1) hide show
  1. examples/gene_classification.ipynb +28 -9
examples/gene_classification.ipynb CHANGED
@@ -139,14 +139,15 @@
139
  "metadata": {},
140
  "outputs": [],
141
  "source": [
142
- "def preprocess_classifier_batch(cell_batch):\n",
143
- " max_batch_len = max([len(i) for i in cell_batch[\"input_ids\"]])\n",
 
144
  " def pad_label_example(example):\n",
145
  " example[\"labels\"] = np.pad(example[\"labels\"], \n",
146
- " (0, max_batch_len-len(example[\"input_ids\"])), \n",
147
  " mode='constant', constant_values=-100)\n",
148
  " example[\"input_ids\"] = np.pad(example[\"input_ids\"], \n",
149
- " (0, max_batch_len-len(example[\"input_ids\"])), \n",
150
  " mode='constant', constant_values=token_dictionary.get(\"<pad>\"))\n",
151
  " example[\"attention_mask\"] = (example[\"input_ids\"] != token_dictionary.get(\"<pad>\")).astype(int)\n",
152
  " return example\n",
@@ -158,10 +159,19 @@
158
  " predict_logits = []\n",
159
  " predict_labels = []\n",
160
  " model.eval()\n",
161
- " for i in range(0, len(evalset), forward_batch_size):\n",
162
- " max_range = min(i+forward_batch_size,len(evalset))\n",
 
 
 
 
 
 
 
 
 
163
  " batch_evalset = evalset.select([i for i in range(i, max_range)])\n",
164
- " padded_batch = preprocess_classifier_batch(batch_evalset)\n",
165
  " padded_batch.set_format(type=\"torch\")\n",
166
  " \n",
167
  " input_data_batch = padded_batch[\"input_ids\"]\n",
@@ -224,7 +234,16 @@
224
  " all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]\n",
225
  " roc_auc = np.sum(all_weighted_roc_auc)\n",
226
  " roc_auc_sd = math.sqrt(np.average((all_roc_auc-roc_auc)**2, weights=wts))\n",
227
- " return mean_tpr, roc_auc, roc_auc_sd"
 
 
 
 
 
 
 
 
 
228
  ]
229
  },
230
  {
@@ -327,7 +346,7 @@
327
  " \n",
328
  " # load model\n",
329
  " model = BertForTokenClassification.from_pretrained(\n",
330
- " \"/path/to/pretrained_model/\",\n",
331
  " num_labels=2,\n",
332
  " output_attentions = False,\n",
333
  " output_hidden_states = False\n",
 
139
  "metadata": {},
140
  "outputs": [],
141
  "source": [
142
+ "def preprocess_classifier_batch(cell_batch, max_len):\n",
143
+ " if max_len == None:\n",
144
+ " max_len = max([len(i) for i in cell_batch[\"input_ids\"]])\n",
145
  " def pad_label_example(example):\n",
146
  " example[\"labels\"] = np.pad(example[\"labels\"], \n",
147
+ " (0, max_len-len(example[\"input_ids\"])), \n",
148
  " mode='constant', constant_values=-100)\n",
149
  " example[\"input_ids\"] = np.pad(example[\"input_ids\"], \n",
150
+ " (0, max_len-len(example[\"input_ids\"])), \n",
151
  " mode='constant', constant_values=token_dictionary.get(\"<pad>\"))\n",
152
  " example[\"attention_mask\"] = (example[\"input_ids\"] != token_dictionary.get(\"<pad>\")).astype(int)\n",
153
  " return example\n",
 
159
  " predict_logits = []\n",
160
  " predict_labels = []\n",
161
  " model.eval()\n",
162
+ " \n",
163
+ " # ensure there is at least 2 examples in each batch to avoid incorrect tensor dims\n",
164
+ " evalset_len = len(evalset)\n",
165
+ " max_divisible = find_largest_div(evalset_len, forward_batch_size)\n",
166
+ " if len(evalset) - max_divisible == 1:\n",
167
+ " evalset_len = max_divisible\n",
168
+ " \n",
169
+ " max_evalset_len = max(evalset.select([i for i in range(evalset_len)])[\"length\"])\n",
170
+ " \n",
171
+ " for i in range(0, evalset_len, forward_batch_size):\n",
172
+ " max_range = min(i+forward_batch_size, evalset_len)\n",
173
  " batch_evalset = evalset.select([i for i in range(i, max_range)])\n",
174
+ " padded_batch = preprocess_classifier_batch(batch_evalset, max_evalset_len)\n",
175
  " padded_batch.set_format(type=\"torch\")\n",
176
  " \n",
177
  " input_data_batch = padded_batch[\"input_ids\"]\n",
 
234
  " all_weighted_roc_auc = [a*b for a,b in zip(all_roc_auc, wts)]\n",
235
  " roc_auc = np.sum(all_weighted_roc_auc)\n",
236
  " roc_auc_sd = math.sqrt(np.average((all_roc_auc-roc_auc)**2, weights=wts))\n",
237
+ " return mean_tpr, roc_auc, roc_auc_sd\n",
238
+ "\n",
239
+ "# Function to find the largest number smaller\n",
240
+ "# than or equal to N that is divisible by k\n",
241
+ "def find_largest_div(N, K):\n",
242
+ " rem = N % K\n",
243
+ " if(rem == 0):\n",
244
+ " return N\n",
245
+ " else:\n",
246
+ " return N - rem"
247
  ]
248
  },
249
  {
 
346
  " \n",
347
  " # load model\n",
348
  " model = BertForTokenClassification.from_pretrained(\n",
349
+ " \"/gladstone/theodoris/lab/ctheodoris/archive/geneformer_files/geneformer/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/\",\n",
350
  " num_labels=2,\n",
351
  " output_attentions = False,\n",
352
  " output_hidden_states = False\n",