---
language: en
tags:
- pytorch
- medical
license: apache-2.0
mask_token: '[MASK]'
widget:
- text: 0.05 ML aflibercept 40 MG/ML Injection is contraindicated with [MASK].
- text: The most important ingredient in Excedrin is [MASK].
---

## Overview 

This repository contains bert_base_uncased_rxnorm_babbage, a [BERT-base-uncased](https://huggingface.co/google-bert/bert-base-uncased) model continually pretrained with masked language modeling on drugs, diseases, and their relationships drawn from RxNorm. 
We hypothesize that this augmentation improves the model's understanding of medical terminology and context.

The pretraining corpus comprises approximately 8.8 million tokens synthesized from drug and disease relations harvested from RxNorm. A few examples are shown below.
```plaintext
ferrous fumarate 191 MG is contraindicated with Hemochromatosis.
24 HR metoprolol succinate 50 MG Extended Release Oral Capsule [Kapspargo] contains the ingredient Metoprolol.
Genvoya has the established pharmacologic class Cytochrome P450 3A Inhibitor.
cefprozil 250 MG Oral Tablet may be used to treat Haemophilus Infections.
mecobalamin 1 MG Sublingual Tablet contains the ingredient Vitamin B 12.
```

The dataset is hosted at [this commit](https://github.com/Su-informatics-lab/drug_disease_graph/blob/3a598cb9d55ffbb52d2f16e61eafff4dfefaf5b1/rxnorm.txt). 
Note that this is the *babbage* version of the corpus, which uses *all* drug and disease relations. 
Don't confuse it with the *ada* version, which uses only a fraction of the relationships (see [the repo](https://github.com/Su-informatics-lab/drug_disease_graph/tree/main) for more information).
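
For illustration, a corpus of this kind can be produced by verbalizing relation triples into natural-language sentences. The sketch below is *not* the actual extraction pipeline from the linked repo; the triple format, relation names, and templates are assumptions.

```python
# Minimal sketch: verbalize drug-disease relation triples into pretraining
# sentences. Triples and templates are illustrative assumptions, not the
# actual RxNorm extraction code.
TEMPLATES = {
    "contraindicated_with": "{drug} is contraindicated with {obj}.",
    "has_ingredient": "{drug} contains the ingredient {obj}.",
    "may_treat": "{drug} may be used to treat {obj}.",
}

triples = [
    ("ferrous fumarate 191 MG", "contraindicated_with", "Hemochromatosis"),
    ("cefprozil 250 MG Oral Tablet", "may_treat", "Haemophilus Infections"),
]

# write one sentence per line, matching the plaintext corpus format above
with open("rxnorm_corpus.txt", "w") as f:
    for drug, relation, obj in triples:
        f.write(TEMPLATES[relation].format(drug=drug, obj=obj) + "\n")
```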

## Training

During pretraining, 15% of the tokens were masked for prediction. 
The model was trained for *20* epochs on 4 NVIDIA A40 (48 GB) GPUs using Python 3.8, with dependencies matched as closely as possible to [requirements.txt](https://github.com/Su-informatics-lab/rxnorm_gatortron/blob/1f15ad349056e22089118519becf1392df084701/requirements.txt). 
Training used a batch size of 16 and a learning rate of 5e-5. 
See the full configuration on [GitHub](https://github.com/Su-informatics-lab/rxnorm_gatortron/blob/main/runs_mlm_bert_base_uncased_rxnorm_babbage.sh) and training curves on [WandB](https://wandb.ai/hainingwang/continual_pretraining_gatortron/runs/1abzfvb9); a rough sketch of the setup appears below.
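
The sketch below shows what a continual-pretraining run with these hyperparameters might look like using Hugging Face's Trainer and DataCollatorForLanguageModeling (mlm_probability=0.15). It is a minimal illustration, not the authors' launch script; the corpus path, sequence length, and output directory are assumptions.

```python
from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForMaskedLM,
    DataCollatorForLanguageModeling, Trainer, TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("google-bert/bert-base-uncased")

# "rxnorm.txt" is assumed to be a local copy of the synthesized corpus
dataset = load_dataset("text", data_files={"train": "rxnorm.txt"})["train"]
dataset = dataset.map(
    lambda ex: tokenizer(ex["text"], truncation=True, max_length=128),
    batched=True, remove_columns=["text"],
)

# mask 15% of tokens for the masked language modeling objective
collator = DataCollatorForLanguageModeling(tokenizer, mlm_probability=0.15)

args = TrainingArguments(
    output_dir="bert_base_uncased_rxnorm_babbage",
    num_train_epochs=20,
    per_device_train_batch_size=16,
    learning_rate=5e-5,
)

Trainer(model=model, args=args, train_dataset=dataset, data_collator=collator).train()
```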


## Usage 
You can use this model for masked language modeling to predict missing words in a given text.
The example below shows how to get started.

```python
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

# load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("Su-informatics-lab/bert_base_uncased_rxnorm_babbage")
model = AutoModelForMaskedLM.from_pretrained("Su-informatics-lab/bert_base_uncased_rxnorm_babbage")

# prepare the input
text = "0.05 ML aflibercept 40 MG/ML Injection is contraindicated with [MASK]."
inputs = tokenizer(text, return_tensors="pt")

# get model predictions
with torch.no_grad():
    outputs = model(**inputs)

# decode the prediction at the masked position
predictions = outputs.logits
masked_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
predicted_token_id = predictions[0, masked_index].argmax(axis=-1)
predicted_token = tokenizer.decode(predicted_token_id)
print(predicted_token)
```
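
The same check can be run more compactly with the `fill-mask` pipeline, which returns the top candidate tokens with their scores:

```python
from transformers import pipeline

fill_mask = pipeline("fill-mask", model="Su-informatics-lab/bert_base_uncased_rxnorm_babbage")

# print the top predictions for the masked slot
for pred in fill_mask("0.05 ML aflibercept 40 MG/ML Injection is contraindicated with [MASK]."):
    print(f"{pred['token_str']}\t{pred['score']:.3f}")
```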

## License
Apache 2.0.