WellcomeBertMesh
WellcomeBertMesh was built by the data science team at the Wellcome Trust to tag biomedical grants with Medical Subject Headings (MeSH). Although it was developed for research grants, it should be applicable to any biomedical text close to the domain it was trained on, which is abstracts from biomedical publications.
Model description
The model is inspired by BertMesh, which is trained on the full text of biomedical publications and uses BioBert as its pretrained model.
WellcomeBertMesh uses PubMedBert from Microsoft, the latest state-of-the-art model in the biomedical domain, and attaches a multilabel attention head, which essentially allows the model to pay attention to different tokens for each label when deciding whether that label applies.
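To make the idea of per-label attention concrete, here is a minimal, dependency-free sketch. It is a generic label-wise attention written for illustration, not the code in model.py: the function name `label_attention`, the dot-product scoring, and all the toy numbers are assumptions, and the real head operates on PubMedBert token states inside PyTorch.

```python
import math


def softmax(scores):
    """Turn raw scores into weights that sum to 1."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def label_attention(token_states, label_queries, label_weights, label_biases):
    """Illustrative label-wise attention (hypothetical helper).

    Each label has its own query vector, so each label builds its own
    weighted summary of the tokens before a per-label sigmoid decides
    whether the label applies.
    """
    hidden_size = len(token_states[0])
    probabilities, attention_maps = [], []
    for query, weight, bias in zip(label_queries, label_weights, label_biases):
        # One attention distribution over tokens for this label.
        attention = softmax([dot(query, state) for state in token_states])
        # Label-specific context vector: attention-weighted sum of token states.
        context = [
            sum(a * state[d] for a, state in zip(attention, token_states))
            for d in range(hidden_size)
        ]
        logit = dot(weight, context) + bias
        probabilities.append(1.0 / (1.0 + math.exp(-logit)))
        attention_maps.append(attention)
    return probabilities, attention_maps


# Toy input: 3 tokens with 2-dimensional states, and 2 labels.
token_states = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
label_queries = [[4.0, 0.0], [0.0, 4.0]]
label_weights = [[2.0, -2.0], [-2.0, 2.0]]
label_biases = [0.0, 0.0]

probs, maps = label_attention(token_states, label_queries, label_weights, label_biases)
print(probs)
print(maps)
```

In this toy setup the first label puts most of its attention on the first token and the second label on the second token, which is the behaviour the head is meant to enable: different labels can look at different parts of the same text.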
We train the model on data from the BioASQ competition, which consists of abstracts from PubMed publications. We use 2016-2019 data for training and 2020-2021 data for testing, which gives us ~2.5M publications to train on and 220K to test on, out of a total of 14M publications. Training WellcomeBertMesh takes 4 days on 8 Nvidia P100 GPUs.
The model achieves a micro F1 of 63% with a threshold of 0.5 for all labels.
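For readers unfamiliar with the metric, the sketch below shows how micro F1 at a fixed threshold is typically computed for multilabel tagging: true positives, false positives, and false negatives are pooled over all documents and labels before precision and recall are taken. The function name `micro_f1`, the strict `>` comparison, and the example data are illustrative assumptions, not the project's evaluation code.

```python
def micro_f1(true_label_sets, predicted_probs, threshold=0.5):
    """Micro-averaged F1: pool counts over every document and label."""
    tp = fp = fn = 0
    for true_labels, probs in zip(true_label_sets, predicted_probs):
        # Assumption: a label is predicted when its probability exceeds the threshold.
        predicted = {label for label, p in probs.items() if p > threshold}
        tp += len(predicted & true_labels)
        fp += len(predicted - true_labels)
        fn += len(true_labels - predicted)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Made-up example: two documents with gold MeSH-style tags and model probabilities.
true_label_sets = [{"Malaria", "Humans"}, {"HIV"}]
predicted_probs = [
    {"Malaria": 0.9, "Humans": 0.4, "HIV": 0.1},
    {"HIV": 0.7, "Malaria": 0.6},
]

score = micro_f1(true_label_sets, predicted_probs)
print(score)
```

Because counts are pooled, frequent labels dominate the score, which is worth keeping in mind when a label set is as large and skewed as MeSH.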
The code for developing the model is open source and can be found at https://github.com/wellcometrust/grants_tagger
How to use
⚠️ You need transformers 4.17+ for the example to work due to its recent support for custom models.
You can use the model straight from the hub, but because its multilabel attention head requires a custom forward function, you have to pass trust_remote_code=True. To get the probabilities for all labels instead of the predicted labels, omit return_labels=True.
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "Wellcome/WellcomeBertMesh"
)
# trust_remote_code=True is required because the model ships a custom forward function
model = AutoModel.from_pretrained(
    "Wellcome/WellcomeBertMesh",
    trust_remote_code=True
)

text = "This grant is about malaria and not about HIV."
inputs = tokenizer([text], padding="max_length")
# return_labels=True returns the predicted labels; omit it to get probabilities
labels = model(**inputs, return_labels=True)
print(labels)
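If you omit return_labels=True and work with the probabilities yourself, you will usually want to apply your own threshold and map label indices back to names. The helper below is a hedged sketch of that step: `probabilities_to_labels` and the `id2label` mapping are hypothetical, and it assumes the probabilities for one text have already been converted to a flat list of floats (the model's actual output may be a batched tensor).

```python
def probabilities_to_labels(probabilities, id2label, threshold=0.5):
    """Keep labels whose probability exceeds the threshold, most confident first."""
    selected = [
        (id2label[index], probability)
        for index, probability in enumerate(probabilities)
        if probability > threshold
    ]
    return sorted(selected, key=lambda pair: pair[1], reverse=True)


# Made-up mapping and scores for a single text; real MeSH label sets are far larger.
id2label = {0: "Malaria", 1: "HIV", 2: "Humans"}
probabilities = [0.92, 0.08, 0.61]

tagged = probabilities_to_labels(probabilities, id2label)
print(tagged)
```

Raising the threshold trades recall for precision, so it can be tuned to how costly a wrong tag is in your application.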
You can inspect the model code by navigating to the files and opening model.py.