{ "cells": [ { "cell_type": "markdown", "id": "234afff3", "metadata": {}, "source": [ "## Geneformer Fine-Tuning for Cell Annotation Application" ] }, { "cell_type": "code", "execution_count": 2, "id": "1cbe6178-ea4d-478a-80a8-65ffaa4c1820", "metadata": {}, "outputs": [], "source": [ "import os\n", "GPU_NUMBER = [0]\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([str(s) for s in GPU_NUMBER])\n", "os.environ[\"NCCL_DEBUG\"] = \"INFO\"" ] }, { "cell_type": "code", "execution_count": 3, "id": "a9885d9f-00ac-4c84-b6a3-b7b648a90f0f", "metadata": {}, "outputs": [], "source": [ "# imports\n", "from collections import Counter\n", "import datetime\n", "import pickle\n", "import subprocess\n", "import seaborn as sns; sns.set()\n", "from datasets import load_from_disk\n", "from sklearn.metrics import accuracy_score, f1_score\n", "from transformers import BertForSequenceClassification\n", "from transformers import Trainer\n", "from transformers.training_args import TrainingArguments\n", "\n", "from geneformer import DataCollatorForCellClassification" ] }, { "cell_type": "markdown", "id": "68bd3b98-5409-4105-b7af-f1ff64ea6a72", "metadata": {}, "source": [ "## Prepare training and evaluation datasets" ] }, { "cell_type": "code", "execution_count": 15, "id": "5735f1b7-7595-4a02-be17-2c5b970ad81a", "metadata": {}, "outputs": [], "source": [ "# load cell type dataset (includes all tissues)\n", "train_dataset=load_from_disk(\"/path/to/cell_type_train_data.dataset\")" ] }, { "cell_type": "code", "execution_count": null, "id": "a4297a02-4c4c-434c-ae55-3387a0b239b5", "metadata": { "collapsed": true, "jupyter": { "outputs_hidden": true }, "tags": [] }, "outputs": [], "source": [ "dataset_list = []\n", "evalset_list = []\n", "organ_list = []\n", "target_dict_list = []\n", "\n", "for organ in Counter(train_dataset[\"organ_major\"]).keys():\n", " # collect list of tissues for fine-tuning (immune and bone marrow are included together)\n", " if organ in [\"bone_marrow\"]: \n", " continue\n", " elif organ==\"immune\":\n", " organ_ids = [\"immune\",\"bone_marrow\"]\n", " organ_list += [\"immune\"]\n", " else:\n", " organ_ids = [organ]\n", " organ_list += [organ]\n", " \n", " print(organ)\n", " \n", " # filter datasets for given organ\n", " def if_organ(example):\n", " return example[\"organ_major\"] in organ_ids\n", " trainset_organ = train_dataset.filter(if_organ, num_proc=16)\n", " \n", " # per scDeepsort published method, drop cell types representing <0.5% of cells\n", " celltype_counter = Counter(trainset_organ[\"cell_type\"])\n", " total_cells = sum(celltype_counter.values())\n", " cells_to_keep = [k for k,v in celltype_counter.items() if v>(0.005*total_cells)]\n", " def if_not_rare_celltype(example):\n", " return example[\"cell_type\"] in cells_to_keep\n", " trainset_organ_subset = trainset_organ.filter(if_not_rare_celltype, num_proc=16)\n", " \n", " # shuffle datasets and rename columns\n", " trainset_organ_shuffled = trainset_organ_subset.shuffle(seed=42)\n", " trainset_organ_shuffled = trainset_organ_shuffled.rename_column(\"cell_type\",\"label\")\n", " trainset_organ_shuffled = trainset_organ_shuffled.remove_columns(\"organ_major\")\n", " \n", " # create dictionary of cell types : label ids\n", " target_names = list(Counter(trainset_organ_shuffled[\"label\"]).keys())\n", " target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))\n", " target_dict_list += [target_name_id_dict]\n", " \n", " # change labels to numerical ids\n", " def classes_to_ids(example):\n", " 
example[\"label\"] = target_name_id_dict[example[\"label\"]]\n", " return example\n", " labeled_trainset = trainset_organ_shuffled.map(classes_to_ids, num_proc=16)\n", " \n", " # create 80/20 train/eval splits\n", " labeled_train_split = labeled_trainset.select([i for i in range(0,round(len(labeled_trainset)*0.8))])\n", " labeled_eval_split = labeled_trainset.select([i for i in range(round(len(labeled_trainset)*0.8),len(labeled_trainset))])\n", " \n", " # filter dataset for cell types in corresponding training set\n", " trained_labels = list(Counter(labeled_train_split[\"label\"]).keys())\n", " def if_trained_label(example):\n", " return example[\"label\"] in trained_labels\n", " labeled_eval_split_subset = labeled_eval_split.filter(if_trained_label, num_proc=16)\n", "\n", " dataset_list += [labeled_train_split]\n", " evalset_list += [labeled_eval_split_subset]" ] }, { "cell_type": "code", "execution_count": 20, "id": "83e20521-597a-4c54-897b-c4d42ea622c2", "metadata": {}, "outputs": [], "source": [ "trainset_dict = dict(zip(organ_list,dataset_list))\n", "traintargetdict_dict = dict(zip(organ_list,target_dict_list))\n", "\n", "evalset_dict = dict(zip(organ_list,evalset_list))" ] }, { "cell_type": "markdown", "id": "10eb110d-ba43-4efc-bc43-1815d6912647", "metadata": {}, "source": [ "## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance" ] }, { "cell_type": "code", "execution_count": 18, "id": "cd7b1cfb-f5cb-460e-ae77-769522ece054", "metadata": {}, "outputs": [], "source": [ "def compute_metrics(pred):\n", " labels = pred.label_ids\n", " preds = pred.predictions.argmax(-1)\n", " # calculate accuracy and macro f1 using sklearn's function\n", " acc = accuracy_score(labels, preds)\n", " macro_f1 = f1_score(labels, preds, average='macro')\n", " return {\n", " 'accuracy': acc,\n", " 'macro_f1': macro_f1\n", " }" ] }, { "cell_type": "markdown", "id": "beaab7a4-cc13-4e8f-b137-ed18ff7b633c", "metadata": {}, "source": [ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example hyperparameters are defined below, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications." 
] }, { "cell_type": "code", "execution_count": 19, "id": "d24e1ab7-0131-44bd-b458-1ce5ba31853e", "metadata": {}, "outputs": [], "source": [ "# set model parameters\n", "# max input size\n", "max_input_size = 2 ** 11 # 2048\n", "\n", "# set training hyperparameters\n", "# max learning rate\n", "max_lr = 5e-5\n", "# how many pretrained layers to freeze\n", "freeze_layers = 0\n", "# number gpus\n", "num_gpus = 1\n", "# number cpu cores\n", "num_proc = 16\n", "# batch size for training and eval\n", "geneformer_batch_size = 12\n", "# learning schedule\n", "lr_schedule_fn = \"linear\"\n", "# warmup steps\n", "warmup_steps = 500\n", "# number of epochs\n", "epochs = 10\n", "# optimizer\n", "optimizer = \"adamw\"" ] }, { "cell_type": "code", "execution_count": 20, "id": "05164c24-5fbf-4372-b26c-a43f3777a88d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "spleen\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
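The per-organ results that follow come from fine-tuning one classifier per tissue with the Hugging Face `Trainer`; the source of that cell is not reproduced in this section, so the loop below is a minimal sketch. The pretrained checkpoint path and output directory are placeholders, and any `TrainingArguments` not covered by the hyperparameters defined above are illustrative choices, not the authors' exact settings.

```python
# Hedged sketch of the per-organ fine-tuning loop (paths are placeholders;
# argument choices beyond the hyperparameters above are illustrative).
for organ in organ_list:
    organ_trainset = trainset_dict[organ]
    organ_evalset = evalset_dict[organ]
    organ_label_dict = traintargetdict_dict[organ]

    # reload the pretrained Geneformer checkpoint with a freshly initialized
    # classification head sized to this organ's cell types
    model = BertForSequenceClassification.from_pretrained(
        "/path/to/pretrained_model/",  # placeholder
        num_labels=len(organ_label_dict.keys()),
        output_attentions=False,
        output_hidden_states=False,
    ).to("cuda")

    # optionally freeze the first `freeze_layers` encoder layers
    if freeze_layers > 0:
        for layer in model.bert.encoder.layer[:freeze_layers]:
            for param in layer.parameters():
                param.requires_grad = False

    output_dir = f"/path/to/output/{organ}/"  # placeholder

    training_args = TrainingArguments(
        learning_rate=max_lr,
        lr_scheduler_type=lr_schedule_fn,  # linear schedule with warmup
        warmup_steps=warmup_steps,
        num_train_epochs=epochs,
        per_device_train_batch_size=geneformer_batch_size,
        per_device_eval_batch_size=geneformer_batch_size,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        logging_steps=max(1, round(len(organ_trainset) / geneformer_batch_size / 10)),
        load_best_model_at_end=True,
        output_dir=output_dir,
    )

    # the Trainer's default optimizer is AdamW, matching `optimizer` above
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=DataCollatorForCellClassification(),
        train_dataset=organ_trainset,
        eval_dataset=organ_evalset,
        compute_metrics=compute_metrics,
    )
    trainer.train()

    # evaluate on the held-out split and save predictions for later inspection
    predictions = trainer.predict(organ_evalset)
    with open(f"{output_dir}predictions.pickle", "wb") as fp:
        pickle.dump(predictions, fp)
```

Reloading the checkpoint inside the loop gives each organ its own freshly initialized classification head, so the per-tissue classifiers share only the pretrained weights, not each other's fine-tuned ones; this is why the "newly initialized" warning repeats once per organ in the outputs below.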
\n", " \n", " \n", " [10280/10280 13:33, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyMacro F1Weighted F1
10.0870000.0680670.9854040.9568390.985483
20.0444000.0752890.9850790.9550690.984898
30.0667000.0787030.9837820.9532400.983959
40.0374000.0571320.9899450.9706190.989883
50.0250000.0616440.9883230.9611260.988211
60.0224000.0653230.9892960.9697370.989362
70.0186000.0637100.9896200.9694360.989579
80.0398000.0659190.9899450.9680650.989802
90.0302000.0613590.9902690.9717000.990314
100.0134000.0591810.9915670.9745990.991552

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [257/257 00:07]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "kidney\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [29340/29340 45:43, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyMacro F1Weighted F1
10.3269000.2991930.9125000.8230670.909627
20.2242000.2395800.9264770.8502370.923902
30.2216000.2428100.9302270.8785530.930349
40.1661000.2641780.9334090.8847590.933031
50.1441000.2792820.9350000.8876590.934987
60.1128000.3076470.9359090.8892390.935365
70.0846000.3263990.9328410.8924470.933191
80.0683000.3326260.9365910.8916290.936354
90.0655000.3481740.9352270.8894840.935040
100.0461000.3553500.9350000.8945780.934971

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [734/734 00:27]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "lung\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [21750/21750 30:32, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyMacro F1Weighted F1
10.3376000.3415230.9063600.7599790.899310
20.2119000.2589540.9284290.8355340.925903
30.2086000.2820810.9304210.8427860.928013
40.1444000.2530470.9354790.8717120.935234
50.1092000.2688330.9394640.8761730.938870
60.1327000.2826970.9405360.8832710.940191
70.0818000.2958640.9408430.8842010.940170
80.0359000.3066000.9419160.8847770.941578
90.0508000.3116770.9405360.8834370.940294
100.0358000.3153600.9408430.8835510.940612

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [544/544 00:19]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "brain\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [8880/8880 11:14, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyMacro F1Weighted F1
10.1631000.1566400.9703450.7364550.960714
20.1498000.1348970.9688440.7471140.960726
30.1056000.1153540.9722220.7752710.964932
40.0869000.2079180.9688440.7079270.958257
50.0564000.1065480.9740990.8398380.971611
60.0376000.1174370.9782280.8565780.975665
70.0305000.1278850.9744740.8562960.973531
80.0193000.1432030.9778530.8593620.975776
90.0074000.1537580.9725980.8528350.972314
100.0172000.1539110.9759760.8581960.974498

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [222/222 00:04]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "placenta\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [6180/6180 10:28, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyMacro F1Weighted F1
10.1287000.1251750.9606260.9357520.959463
20.0640000.2156070.9514560.9205790.949828
30.0513000.2030440.9611650.9341950.959470
40.0453000.1157010.9789640.9663870.978788
50.0482000.1494840.9735710.9589270.973305
60.0409000.1343390.9789640.9674660.978899
70.0016000.1599000.9784250.9667130.978211
80.0024000.1253510.9795040.9680640.979428
90.0094000.1201320.9805830.9696310.980506
100.0015000.1378640.9789640.9671800.978825

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [155/155 00:05]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "immune\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [17140/17140 22:02, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyMacro F1Weighted F1
10.2889000.2315820.9367700.8684050.934816
20.2032000.2062920.9373540.8886610.939555
30.1835000.1958110.9449420.8911490.944008
40.1510000.2195810.9476650.9065780.947093
50.0900000.2471200.9466930.8988120.945808
60.0604000.2496620.9484440.9050140.947975
70.0713000.2727670.9494160.9115140.949748
80.0526000.3050510.9453310.9023480.944987
90.0269000.2941350.9486380.9040580.948296
100.0345000.2920290.9501950.9085470.949753

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [429/429 00:13]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "large_intestine\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [33070/33070 43:02, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyMacro F1Weighted F1
10.3062000.3124310.9082660.7862420.900768
20.2239000.2480960.9251010.8412510.920987
30.1736000.2599970.9259070.8503480.926290
40.1629000.2823060.9250000.8736690.925531
50.1434000.2544940.9379030.8767490.937836
60.1045000.2899420.9346770.8753330.934339
70.0803000.3139140.9354840.8772710.934986
80.0635000.3398680.9362900.8822670.936187
90.0425000.3457840.9389110.8829630.938682
100.0389000.3521990.9395160.8855090.939497

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [827/827 00:26]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "pancreas\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [18280/18280 23:32, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyMacro F1Weighted F1
10.3401000.3432000.8962440.6556610.879469
20.1783000.2240330.9308900.8597720.925342
30.1542000.2080340.9412840.8870120.939485
40.1212000.2166600.9403720.8807160.939431
50.0999000.2542550.9405540.8890880.938300
60.0658000.2674290.9427430.8976820.942815
70.0612000.2825090.9454780.8987970.943881
80.0368000.3017810.9438370.9038160.944163
90.0354000.3170260.9425600.9022410.942071
100.0142000.3132590.9467540.9049550.946129

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [457/457 00:11]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "liver\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [18690/18690 26:56, Epoch 10/10]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossAccuracyMacro F1Weighted F1
10.3885000.3855030.8781880.6738870.871348
20.3159000.3027750.9074370.7541820.903474
30.2426000.3218440.9079720.7795040.905881
40.2386000.3231190.9115390.7909220.910299
50.1601000.3282030.9156410.7934900.913836
60.1631000.3489420.9174250.8136040.916911
70.1241000.3737990.9168900.8203550.916688
80.1187000.3994740.9168900.8188390.916640
90.0668000.4143630.9176030.8307030.917226
100.0758000.4138280.9190300.8281490.918506

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n", ":54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", " batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n" ] }, { "data": { "text/html": [ "\n", "

\n", " \n", " \n", " [468/468 00:39]\n", "
\n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "for organ in organ_list:\n", " print(organ)\n", " organ_trainset = trainset_dict[organ]\n", " organ_evalset = evalset_dict[organ]\n", " organ_label_dict = traintargetdict_dict[organ]\n", " \n", " # set logging steps\n", " logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)\n", " \n", " # reload pretrained model\n", " model = BertForSequenceClassification.from_pretrained(\"/path/to/pretrained_model/\", \n", " num_labels=len(organ_label_dict.keys()),\n", " output_attentions = False,\n", " output_hidden_states = False).to(\"cuda\")\n", " \n", " # define output directory path\n", " current_date = datetime.datetime.now()\n", " datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n", " output_dir = f\"/path/to/models/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/\"\n", " \n", " # ensure not overwriting previously saved model\n", " saved_model_test = os.path.join(output_dir, f\"pytorch_model.bin\")\n", " if os.path.isfile(saved_model_test) == True:\n", " raise Exception(\"Model already saved to this directory.\")\n", "\n", " # make output directory\n", " subprocess.call(f'mkdir {output_dir}', shell=True)\n", " \n", " # set training arguments\n", " training_args = {\n", " \"learning_rate\": max_lr,\n", " \"do_train\": True,\n", " \"do_eval\": True,\n", " \"evaluation_strategy\": \"epoch\",\n", " \"save_strategy\": \"epoch\",\n", " \"logging_steps\": logging_steps,\n", " \"group_by_length\": True,\n", " \"length_column_name\": \"length\",\n", " \"disable_tqdm\": False,\n", " \"lr_scheduler_type\": lr_schedule_fn,\n", " \"warmup_steps\": warmup_steps,\n", " \"weight_decay\": 0.001,\n", " \"per_device_train_batch_size\": geneformer_batch_size,\n", " \"per_device_eval_batch_size\": geneformer_batch_size,\n", " \"num_train_epochs\": epochs,\n", " \"load_best_model_at_end\": True,\n", " \"output_dir\": output_dir,\n", " }\n", " \n", " training_args_init = TrainingArguments(**training_args)\n", "\n", " # create the trainer\n", " trainer = Trainer(\n", " model=model,\n", " args=training_args_init,\n", " data_collator=DataCollatorForCellClassification(),\n", " train_dataset=organ_trainset,\n", " eval_dataset=organ_evalset,\n", " compute_metrics=compute_metrics\n", " )\n", " # train the cell type classifier\n", " trainer.train()\n", " predictions = trainer.predict(organ_evalset)\n", " with open(f\"{output_dir}predictions.pickle\", \"wb\") as fp:\n", " pickle.dump(predictions, fp)\n", " trainer.save_metrics(\"eval\",predictions.metrics)\n", " trainer.save_model(output_dir)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11" }, "vscode": { "interpreter": { "hash": "eba1599a1f7e611c14c87ccff6793920aa63510b01fc0e229d6dd014149b8829" } } }, "nbformat": 4, "nbformat_minor": 5 }