{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "You can run either this notebook locally (if you have all the dependencies and a GPU) or on Google Colab.\n",
    "\n",
    "Instructions for setting up Colab are as follows:\n",
    "1. Open a new Python 3 notebook.\n",
    "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n",
    "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n",
    "4. Run this cell to set up dependencies.\n",
    "\"\"\"\n",
    "\n",
    "## Install NeMo if using google collab or if its not installed locally\n",
    "BRANCH = 'r1.17.0'\n",
    "!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Install dependencies\n",
    "!pip install wget\n",
    "!pip install faiss-gpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import faiss\n",
    "import torch\n",
    "import wget\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from omegaconf import OmegaConf\n",
    "from pytorch_lightning import Trainer\n",
    "from IPython.display import display\n",
    "from tqdm import tqdm\n",
    "\n",
    "from nemo.collections import nlp as nemo_nlp\n",
    "from nemo.utils.exp_manager import exp_manager"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Entity Linking"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Task Description\n",
    "[Entity linking](https://en.wikipedia.org/wiki/Entity_linking) is the process of connecting concepts mentioned in natural language to their canonical forms stored in a knowledge base. For example, say a knowledge base contained the entity 'ID3452 influenza' and we wanted to process some natural language containing the sentence \"The patient has flu like symptoms\". An entity linking model would match the word 'flu' to the knowledge base entity 'ID3452 influenza', allowing for disambiguation and normalization of concepts referenced in text. Entity linking applications range from helping automate data ingestion to assisting in real time dialogue concept normalization. We will be focusing on entity linking in the medical domain for this demo, but the entity linking model, dataset, and training code within NVIDIA NeMo can be applied to other domains like finance and retail.\n",
    "\n",
    "Within NeMo and this tutorial we use the entity linking approach described in Liu et. al's NAACL 2021 \"[Self-alignment Pre-training for Biomedical Entity Representations](https://arxiv.org/abs/2010.11784v2)\". The main idea behind this approach is to reshape an initial concept embedding space such that synonyms of the same concept are pulled closer together and unrelated concepts are pushed further apart. The concept embeddings from this reshaped space can then be used to build a knowledge base embedding index. This index stores concept IDs mapped to their respective concept embeddings in a format conducive to efficient nearest neighbor search. We can link query concepts to their canonical forms in the knowledge base by performing a nearest neighbor search- matching concept query embeddings to the most similar concepts embeddings in the knowledge base index. \n",
    "\n",
    "In this tutorial we will be using the [faiss](https://github.com/facebookresearch/faiss) library to build our concept index."
   ]
  },
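  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick illustration of the nearest neighbor search mechanic, here is a minimal sketch using faiss on a few made-up 4-dimensional \"embeddings\" that stand in for real concept vectors. It only shows how an index matches a query vector to its closest entries; the actual concept embeddings come from the encoder we train later in this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import faiss\n",
    "import numpy as np\n",
    "\n",
    "# Toy \"knowledge base\": 3 concepts with made-up 4-dim embeddings (row i <-> concept i)\n",
    "toy_kb = np.array([[0.9, 0.1, 0.0, 0.0],\n",
    "                   [0.0, 0.8, 0.2, 0.0],\n",
    "                   [0.1, 0.0, 0.0, 0.9]], dtype=np.float32)\n",
    "\n",
    "# Normalize rows so the inner product equals cosine similarity\n",
    "faiss.normalize_L2(toy_kb)\n",
    "\n",
    "toy_index = faiss.IndexFlatIP(4)  # exact inner product search\n",
    "toy_index.add(toy_kb)\n",
    "\n",
    "# A toy query embedding most similar to concept 0\n",
    "toy_query = np.array([[1.0, 0.2, 0.0, 0.1]], dtype=np.float32)\n",
    "faiss.normalize_L2(toy_query)\n",
    "\n",
    "scores, ids = toy_index.search(toy_query, 2)\n",
    "print(\"closest concept rows:\", ids, \"similarities:\", scores)"
   ]
  },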
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Self Alignment Pretraining\n",
    "Self-Alignment pretraining is a second stage pretraining of an existing encoder (called second stage because the encoder model can be further finetuned after this more general pretraining step). The dataset used during training consists of pairs of concept synonyms that map to the same ID. At each training iteration, we only select *hard* examples present in the mini batch to calculate the loss and update the model weights. In this context, a hard example is an example where a concept is closer to an unrelated concept in the mini batch than it is to the synonym concept it is paired with by some margin. I encourage you to take a look at [section 2 of the paper](https://arxiv.org/pdf/2010.11784.pdf) for a more formal and in depth description of how hard examples are selected.\n",
    "\n",
    "We then use a [metric learning loss](https://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf) calculated from the hard examples selected. This loss helps reshape the embedding space. The concept representation space is rearranged to be more suitable for entity matching via embedding cosine similarity. \n",
    "\n",
    "Now that we have idea of what's going on, let's get started!"
   ]
  },
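  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the mining and loss concrete, below is a standalone PyTorch sketch of a multi-similarity loss with in-batch hard pair mining, loosely following section 2 of the paper. The hyperparameter values are illustrative placeholders rather than the ones NeMo uses, and the real implementation lives inside NeMo's `EntityLinkingModel`; this version only shows the mechanics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def multi_similarity_loss(embs, labels, alpha=2.0, beta=50.0, lam=0.5, eps=0.1):\n",
    "    # Normalize so dot products are cosine similarities\n",
    "    embs = F.normalize(embs, dim=1)\n",
    "    sim = embs @ embs.t()  # (N, N) pairwise similarity matrix\n",
    "    losses = []\n",
    "\n",
    "    for i in range(sim.size(0)):\n",
    "        pos_mask = labels == labels[i]\n",
    "        pos_mask[i] = False  # exclude the self-pair\n",
    "        pos_sims, neg_sims = sim[i][pos_mask], sim[i][labels != labels[i]]\n",
    "        if pos_sims.numel() == 0 or neg_sims.numel() == 0:\n",
    "            continue\n",
    "\n",
    "        # Hard pair mining: keep negatives more similar than the least similar\n",
    "        # positive (minus eps), and positives less similar than the most\n",
    "        # similar negative (plus eps)\n",
    "        hard_negs = neg_sims[neg_sims + eps > pos_sims.min()]\n",
    "        hard_poss = pos_sims[pos_sims - eps < neg_sims.max()]\n",
    "        if hard_negs.numel() == 0 or hard_poss.numel() == 0:\n",
    "            continue\n",
    "\n",
    "        # Multi-similarity loss terms for this anchor\n",
    "        pos_term = torch.log1p(torch.exp(-alpha * (hard_poss - lam)).sum()) / alpha\n",
    "        neg_term = torch.log1p(torch.exp(beta * (hard_negs - lam)).sum()) / beta\n",
    "        losses.append(pos_term + neg_term)\n",
    "\n",
    "    return torch.stack(losses).mean() if losses else embs.new_zeros(())\n",
    "\n",
    "# Tiny smoke test on random embeddings with 3 concept IDs\n",
    "toy_embs = torch.randn(8, 16)\n",
    "toy_labels = torch.tensor([0, 0, 1, 1, 2, 2, 0, 1])\n",
    "print(multi_similarity_loss(toy_embs, toy_labels))"
   ]
  },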
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset Preprocessing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download data into project directory\n",
    "PROJECT_DIR = \".\" #Change if you don't want the current directory to be the project dir\n",
    "DATA_DIR = os.path.join(PROJECT_DIR, \"tiny_example_data\")\n",
    "\n",
    "if not os.path.isdir(os.path.join(DATA_DIR)):\n",
    "    wget.download('https://dldata-public.s3.us-east-2.amazonaws.com/tiny_example_data.zip',\n",
    "                  os.path.join(PROJECT_DIR, \"tiny_example_data.zip\"))\n",
    "\n",
    "    !unzip {PROJECT_DIR}/tiny_example_data.zip -d {PROJECT_DIR}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this tutorial we will be using a tiny toy dataset to demonstrate how to use NeMo's entity linking model functionality. The dataset includes synonyms for 12 medical concepts. Entity phrases with the same ID are synonyms for the same concept. For example, \"*chronic kidney failure*\", \"*gradual loss of kidney function*\", and \"*CKD*\" are all synonyms of concept ID 5. Here's the dataset before preprocessing:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_data = pd.read_csv(os.path.join(DATA_DIR, \"tiny_example_dev_data.csv\"), names=[\"ID\", \"CONCEPT\"], index_col=False)\n",
    "print(raw_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've already paired off the concepts for this dataset with the format `ID concept_synonym1 concept_synonym2`. Here are the first ten rows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_data = pd.read_table(os.path.join(DATA_DIR, \"tiny_example_train_pairs.tsv\"), names=[\"ID\", \"CONCEPT_SYN1\", \"CONCEPT_SYN2\"], delimiter='\\t')\n",
    "print(training_data.head(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Use the [Unified Medical Language System (UMLS)](https://www.nlm.nih.gov/research/umls/index.html) dataset for full medical domain entity linking training. The data contains over 9 million entities and is a table of medical concepts with their corresponding concept IDs (CUI). After [requesting a free license and making a UMLS Terminology Services (UTS) account](https://www.nlm.nih.gov/research/umls/index.html), the [entire UMLS dataset](https://www.nlm.nih.gov/research/umls/licensedcontent/umlsknowledgesources.html) can be downloaded from the NIH's website. If you've cloned the NeMo repo you can run the data processing script located in `examples/nlp/entity_linking/data/umls_dataset_processing.py` on the full dataset. This script will take in the initial table of UMLS concepts and produce a .tsv file with each row formatted as `CUI\\tconcept_synonym1\\tconcept_synonym2`. Once the UMLS dataset .RRF file is downloaded, the script can be run from the `examples/nlp/entity_linking` directory like so: \n",
    "```\n",
    "python data/umls_dataset_processing.py\n",
    "```"
   ]
  },
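  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate the pairing step, here is a small sketch, on a made-up table rather than real UMLS data, of how an ID-to-synonym table can be turned into training pairs by enumerating synonym combinations per concept. The actual script also handles UMLS-specific filtering and formatting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import combinations\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Made-up stand-in for a raw ID/CONCEPT table\n",
    "toy_concepts = pd.DataFrame({\n",
    "    \"ID\": [1, 1, 1, 2, 2],\n",
    "    \"CONCEPT\": [\"influenza\", \"flu\", \"grippe\", \"headache\", \"head pain\"],\n",
    "})\n",
    "\n",
    "# Pair every synonym with every other synonym of the same concept\n",
    "pairs = [(cid, syn1, syn2)\n",
    "         for cid, group in toy_concepts.groupby(\"ID\")\n",
    "         for syn1, syn2 in combinations(group[\"CONCEPT\"], 2)]\n",
    "\n",
    "print(pd.DataFrame(pairs, columns=[\"ID\", \"CONCEPT_SYN1\", \"CONCEPT_SYN2\"]))"
   ]
  },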
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Second stage pretrain a BERT Base encoder on the self-alignment pretraining task (SAP) for improved entity linking. Using a GPU, the model should take 5 minutes or less to train on this example dataset and training progress will be output below the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Download config\n",
    "wget.download(f\"https://raw.githubusercontent.com/NVIDIA/NeMo/{BRANCH}/examples/nlp/entity_linking/conf/tiny_example_entity_linking_config.yaml\",\n",
    "              os.path.join(PROJECT_DIR, \"tiny_example_entity_linking_config.yaml\"))\n",
    "\n",
    "# Load in config file\n",
    "cfg = OmegaConf.load(os.path.join(PROJECT_DIR, \"tiny_example_entity_linking_config.yaml\"))\n",
    "\n",
    "# Set config file variables\n",
    "cfg.project_dir = PROJECT_DIR\n",
    "cfg.model.nemo_path = os.path.join(PROJECT_DIR, \"tiny_example_sap_bert_model.nemo\")\n",
    "cfg.model.train_ds.data_file = os.path.join(DATA_DIR, \"tiny_example_train_pairs.tsv\")\n",
    "cfg.model.validation_ds.data_file = os.path.join(DATA_DIR, \"tiny_example_validation_pairs.tsv\")\n",
    "\n",
    "# remove distributed training flags\n",
    "cfg.trainer.strategy = None\n",
    "cfg.trainer.accelerator = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize the trainer and model\n",
    "trainer = Trainer(**cfg.trainer)\n",
    "exp_manager(trainer, cfg.get(\"exp_manager\", None))\n",
    "model = nemo_nlp.models.EntityLinkingModel(cfg=cfg.model, trainer=trainer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train and save the model\n",
    "trainer.fit(model)\n",
    "model.save_to(cfg.model.nemo_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can run the script at `examples/nlp/entity_linking/self_alignment_pretraining.py` to train a model on a larger dataset. Run\n",
    "\n",
    "```\n",
    "python self_alignment_pretraining.py project_dir=.\n",
    "```\n",
    "from the `examples/nlp/entity_linking` directory."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Evaluation\n",
    "\n",
    "Let's evaluate our freshly trained model and compare its performance with a BERT Base encoder that hasn't undergone self-alignment pretraining. We first need to restore our trained model and load our BERT Base Baseline model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Restore second stage pretrained model\n",
    "sap_model_cfg = cfg\n",
    "sap_model_cfg.index.index_save_name = os.path.join(PROJECT_DIR, \"tiny_example_entity_linking_index\")\n",
    "sap_model_cfg.index.index_ds.data_file = os.path.join(DATA_DIR, \"tiny_example_index_data.tsv\")\n",
    "sap_model = nemo_nlp.models.EntityLinkingModel.restore_from(sap_model_cfg.model.nemo_path).to(device)\n",
    "\n",
    "# Load original model\n",
    "base_model_cfg = OmegaConf.load(os.path.join(PROJECT_DIR, \"tiny_example_entity_linking_config.yaml\"))\n",
    "\n",
    "# Set train/val datasets to None to avoid loading datasets associated with training\n",
    "base_model_cfg.model.train_ds = None\n",
    "base_model_cfg.model.validation_ds = None\n",
    "base_model_cfg.index.index_save_name = os.path.join(PROJECT_DIR, \"base_model_index\")\n",
    "base_model_cfg.index.index_ds.data_file = os.path.join(DATA_DIR, \"tiny_example_index_data.tsv\")\n",
    "base_model = nemo_nlp.models.EntityLinkingModel(base_model_cfg.model).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We are going evaluate our model on a nearest neighbor task using top 1 and top 5 accuracies as our metric. We will be using a tiny example test knowledge base and test queries. For this evaluation we are going to be comparing every test query with every concept vector in our test set knowledge base. We will rank each item in the knowledge base by its cosine similarity with the test query. We'll then compare the IDs of the predicted most similar test knowledge base concepts with our ground truth query IDs to calculate top 1 and top 5 accuracies. For this metric higher is better."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Helper function to get data embeddings\n",
    "def get_embeddings(model, dataloader):\n",
    "    embeddings, cids = [], []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in tqdm(dataloader):\n",
    "            input_ids, token_type_ids, attention_mask, batch_cids = batch\n",
    "            batch_embeddings = model.forward(input_ids=input_ids.to(device), \n",
    "                                             token_type_ids=token_type_ids.to(device), \n",
    "                                             attention_mask=attention_mask.to(device))\n",
    "\n",
    "            # Accumulate index embeddings and their corresponding IDs\n",
    "            embeddings.extend(batch_embeddings.cpu().detach().numpy())\n",
    "            cids.extend(batch_cids)\n",
    "            \n",
    "    return embeddings, cids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, test_kb, test_queries, ks):\n",
    "    # Initialize knowledge base and query data loaders\n",
    "    test_kb_dataloader = model.setup_dataloader(test_kb, is_index_data=True)\n",
    "    test_query_dataloader = model.setup_dataloader(test_queries, is_index_data=True)\n",
    "    \n",
    "    # Get knowledge base and query embeddings\n",
    "    test_kb_embs, test_kb_cids = get_embeddings(model, test_kb_dataloader)\n",
    "    test_query_embs, test_query_cids = get_embeddings(model, test_query_dataloader)\n",
    "\n",
    "    # Calculate the cosine distance between each query and knowledge base concept\n",
    "    score_matrix = np.matmul(np.array(test_query_embs), np.array(test_kb_embs).T)\n",
    "    accs = {k : 0 for k in ks}\n",
    "    \n",
    "    # Compare the knowledge base IDs of the knowledge base entities with \n",
    "    # the smallest cosine distance from the query \n",
    "    for query_idx in tqdm(range(len(test_query_cids))):\n",
    "        query_emb = test_query_embs[query_idx]\n",
    "        query_cid = test_query_cids[query_idx]\n",
    "        query_scores = score_matrix[query_idx]\n",
    "\n",
    "        for k in ks:\n",
    "            topk_idxs = np.argpartition(query_scores, -k)[-k:]\n",
    "            topk_cids = [test_kb_cids[idx] for idx in topk_idxs]\n",
    "            \n",
    "            # If the correct query ID is amoung the top k closest kb IDs\n",
    "            # the model correctly linked the entity\n",
    "            match = int(query_cid in topk_cids)\n",
    "            accs[k] += match\n",
    "\n",
    "    for k in ks:\n",
    "        accs[k] /= len(test_query_cids)\n",
    "                \n",
    "    return accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create configs for our test data\n",
    "test_kb = OmegaConf.create({\n",
    "    \"data_file\": os.path.join(DATA_DIR, \"tiny_example_test_kb.tsv\"),\n",
    "    \"max_seq_length\": 128,\n",
    "    \"batch_size\": 10,\n",
    "    \"shuffle\": False,\n",
    "})\n",
    "\n",
    "test_queries = OmegaConf.create({\n",
    "    \"data_file\": os.path.join(DATA_DIR, \"tiny_example_test_queries.tsv\"),\n",
    "    \"max_seq_length\": 128,\n",
    "    \"batch_size\": 10,\n",
    "    \"shuffle\": False,\n",
    "})\n",
    "\n",
    "ks = [1, 5]\n",
    "\n",
    "# Evaluate both models on our test data\n",
    "base_accs = evaluate(base_model, test_kb, test_queries, ks)\n",
    "base_accs[\"Model\"] = \"BERT Base Baseline\"\n",
    "\n",
    "sap_accs = evaluate(sap_model, test_kb, test_queries, ks)\n",
    "sap_accs[\"Model\"] = \"BERT + SAP\"\n",
    "\n",
    "print(\"Top 1 and Top 5 Accuracy Comparison:\")\n",
    "results_df = pd.DataFrame([base_accs, sap_accs], columns=[\"Model\", 1, 5])\n",
    "results_df = results_df.style.set_properties(**{'text-align': 'left', }).set_table_styles([dict(selector='th', props=[('text-align', 'left')])])\n",
    "display(results_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The purpose of this section was to show an example of evaluating your entity linking model. This evaluation set contains very little data, and no serious conclusions should be drawn about model performance. Top 1 accuracy should be between 0.7 and 1.0 for both models and top 5 accuracy should be between 0.8 and 1.0. When evaluating a model trained on a larger dataset, you can use a nearest neighbors index to speed up the evaluation time."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building an Index"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To qualitatively observe the improvement we gain from the second stage pretraining, let's build two indices. One will be built with BERT base embeddings before self-alignment pretraining and one will be built with the model we just trained. Our knowledge base in this tutorial will be in the same domain and have some overlapping concepts as the training set. This data file is formatted as `ID\\tconcept`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `EntityLinkingDataset` class can load the data used for training the entity linking encoder as well as for building the index if the `is_index_data` flag is set to true. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_index(cfg, model):\n",
    "    # Setup index dataset loader\n",
    "    index_dataloader = model.setup_dataloader(cfg.index.index_ds, is_index_data=True)\n",
    "    \n",
    "    # Get index dataset embeddings\n",
    "    embeddings, _ = get_embeddings(model, index_dataloader)\n",
    "    \n",
    "    # Train IVFFlat index using faiss\n",
    "    embeddings = np.array(embeddings)\n",
    "    quantizer = faiss.IndexFlatL2(cfg.index.dims)\n",
    "    index = faiss.IndexIVFFlat(quantizer, cfg.index.dims, cfg.index.nlist)\n",
    "    index = faiss.index_cpu_to_all_gpus(index)\n",
    "    index.train(embeddings)\n",
    "    \n",
    "    # Add concept embeddings to index\n",
    "    for i in tqdm(range(0, embeddings.shape[0], cfg.index.index_batch_size)):\n",
    "            index.add(embeddings[i:i+cfg.index.index_batch_size])\n",
    "\n",
    "    # Save index\n",
    "    faiss.write_index(faiss.index_gpu_to_cpu(index), cfg.index.index_save_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "build_index(sap_model_cfg, sap_model.to(device))\n",
    "build_index(base_model_cfg, base_model.to(device))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Entity Linking via Nearest Neighbor Search"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now it's time to query our indices! We are going to query both our index built with embeddings from BERT Base, and our index with embeddings built from the SAP BERT model we trained. Our sample query phrases will be \"*high blood sugar*\" and \"*head pain*\". \n",
    "\n",
    "To query our indices, we first need to get the embedding of each query from the corresponding encoder model. We can then pass these query embeddings into the faiss index which will perform a nearest neighbor search, using cosine distance to compare the query embedding with embeddings present in the index. Once we get a list of knowledge base index concept IDs most closely matching our query, all that is left to do is map the IDs to a representative string describing the concept. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def query_index(cfg, model, index, queries, id2string):\n",
    "    # Get query embeddings from our entity linking encoder model\n",
    "    query_embs = get_query_embedding(queries, model).cpu().detach().numpy()\n",
    "    \n",
    "    # Use query embedding to find closest concept embedding in knowledge base\n",
    "    distances, neighbors = index.search(query_embs, cfg.index.top_n)\n",
    "    \n",
    "    # Get the canonical strings corresponding to the IDs of the query's nearest neighbors in the kb \n",
    "    neighbor_concepts = [[id2string[concept_id] for concept_id in query_neighbor] \\\n",
    "                                                for query_neighbor in neighbors]\n",
    "    \n",
    "    # Display most similar concepts in the knowledge base. \n",
    "    for query_idx in range(len(queries)):\n",
    "        print(f\"\\nThe most similar concepts to {queries[query_idx]} are:\")\n",
    "        for cid, concept, dist in zip(neighbors[query_idx], neighbor_concepts[query_idx], distances[query_idx]):\n",
    "            print(cid, concept, 1 - dist)\n",
    "\n",
    "    \n",
    "def get_query_embedding(queries, model):\n",
    "    # Tokenize our queries\n",
    "    model_input =  model.tokenizer(queries,\n",
    "                                   add_special_tokens = True,\n",
    "                                   padding = True,\n",
    "                                   truncation = True,\n",
    "                                   max_length = 512,\n",
    "                                   return_token_type_ids = True,\n",
    "                                   return_attention_mask = True)\n",
    "    \n",
    "    # Pass tokenized input into model\n",
    "    query_emb =  model.forward(input_ids=torch.LongTensor(model_input[\"input_ids\"]).to(device),\n",
    "                               token_type_ids=torch.LongTensor(model_input[\"token_type_ids\"]).to(device),\n",
    "                               attention_mask=torch.LongTensor(model_input[\"attention_mask\"]).to(device))\n",
    "    \n",
    "    return query_emb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load indices\n",
    "sap_index = faiss.read_index(sap_model_cfg.index.index_save_name)\n",
    "base_index = faiss.read_index(base_model_cfg.index.index_save_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map concept IDs to one canonical string\n",
    "index_data = open(sap_model_cfg.index.index_ds.data_file, \"r\", encoding='utf-8-sig')\n",
    "id2string = {}\n",
    "\n",
    "for line in index_data:\n",
    "    cid, concept = line.split(\"\\t\")\n",
    "    id2string[int(cid) - 1] = concept.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "id2string"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Some sample queries\n",
    "queries = [\"high blood sugar\", \"head pain\"]\n",
    "\n",
    "# Query BERT Base\n",
    "print(\"BERT Base output before Self Alignment Pretraining:\")\n",
    "query_index(base_model_cfg, base_model, base_index, queries, id2string)\n",
    "print(\"\\n\" + \"-\" * 50 + \"\\n\")\n",
    "\n",
    "# Query SAP BERT\n",
    "print(\"SAP BERT output after Self Alignment Pretraining:\")\n",
    "query_index(sap_model_cfg, sap_model, sap_index, queries, id2string)\n",
    "print(\"\\n\" + \"-\" * 50 + \"\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Even after only training on this tiny amount of data, the qualitative performance boost from self-alignment pretraining is visible. The baseline model links \"*high blood sugar*\" to the entity \"*6 diabetes*\" while our SAP BERT model accurately links \"*high blood sugar*\" to \"*Hyperinsulinemia*\". Similarly, \"*head pain*\" and \"*Myocardial infraction*\" are not the same concept, but \"*head pain*\" and \"*Headache*\" are."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For larger knowledge bases keeping the default embedding size might be too large and cause out of memory issues. You can apply PCA or some other dimensionality reduction method to your data to reduce its memory footprint. Code for creating a text file of all the UMLS entities in the correct format needed to build an index and creating a dictionary mapping concept ids to canonical concept strings can be found here `examples/nlp/entity_linking/data/umls_dataset_processing.py`. \n",
    "\n",
    "The code for extracting knowledge base concept embeddings, training and applying a PCA transformation to the embeddings, building a faiss index and querying the index from the command line is located at `examples/nlp/entity_linking/build_index.py` and `examples/nlp/entity_linking/query_index.py`. \n",
    "\n",
    "If you've cloned the NeMo repo, both of these steps can be run as follows on the command line from the `examples/nlp/entity_linking/` directory.\n",
    "\n",
    "```\n",
    "python data/umls_dataset_processing.py --index\n",
    "python build_index.py --restore\n",
    "python query_index.py --restore\n",
    "```\n",
    "By default the project directory will be \".\" but can be changed by adding the flag `--project_dir=<PATH>` after each of the above commands. Intermediate steps of the index building process are saved. In the occurrence of an error, previously completed steps do not need to be rerun. "
   ]
  },
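  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the dimensionality reduction idea, the cell below trains a faiss `PCAMatrix` on random stand-in vectors (not real concept embeddings) and indexes the reduced vectors. The dimensions and `nlist` value here are placeholders; `build_index.py` handles this step for real data and may differ in its details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import faiss\n",
    "import numpy as np\n",
    "\n",
    "# Random stand-in for real concept embeddings: 1000 vectors of dim 768\n",
    "full_embs = np.random.rand(1000, 768).astype(np.float32)\n",
    "\n",
    "# Train a PCA transform from 768 -> 256 dims and apply it\n",
    "pca = faiss.PCAMatrix(768, 256)\n",
    "pca.train(full_embs)\n",
    "reduced_embs = pca.apply_py(full_embs)\n",
    "print(reduced_embs.shape)  # (1000, 256)\n",
    "\n",
    "# The reduced vectors can be indexed as before, using the smaller dimension\n",
    "pca_quantizer = faiss.IndexFlatL2(256)\n",
    "pca_index = faiss.IndexIVFFlat(pca_quantizer, 256, 10)\n",
    "pca_index.train(reduced_embs)\n",
    "pca_index.add(reduced_embs)\n",
    "\n",
    "# Note: queries must pass through the same PCA transform before searching"
   ]
  },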
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Command Recap"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a recap of the commands and steps to repeat this process on the full UMLS dataset. \n",
    "\n",
    "1) Download the UMLS dataset file `MRCONSO.RRF` from the NIH website and place it in the `examples/nlp/entity_linking/data` directory.\n",
    "\n",
    "2) Run the following commands from the `examples/nlp/entity_linking` directory\n",
    "```\n",
    "python data/umls_dataset_processing.py\n",
    "python self_alignment_pretraining.py project_dir=. \n",
    "python data/umls_dataset_processing.py --index\n",
    "python build_index.py --restore\n",
    "python query_index.py --restore\n",
    "```\n",
    "The model will take ~24hrs to train on two GPUs and ~48hrs to train on one GPU. By default the project directory will be \".\" but can be changed by adding the flag `--project_dir=<PATH>` after each of the above commands and changing `project_dir=<PATH>` in the `self_alignment_pretraining.py` command. If you change the project directory, you should also move the `MRCONOSO.RRF` file to a `data` sub directory within the one you've specified. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As mentioned in the introduction, entity linking within NVIDIA NeMo is not limited to the medical domain. The same data processing and training steps can be applied to a variety of domains and use cases. You can edit the datasets used as well as training and loss function hyperparameters within your config file to better suit your domain."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}