NOTE: Notable generation Of patient Text summaries through an Efficient approach based on direct preference optimization
The discharge summary (DS) is a crucial document in the patient journey, as it encompasses all events across multiple visits: medications, imaging and laboratory tests, surgeries and procedures, and admission and discharge. Summarizing a patient's progress is essential because it strongly influences future care and planning. Clinicians therefore face the laborious, resource-intensive task of manually collecting, organizing, and combining all the data needed for a DS. We propose NOTE, which stands for "Notable generation Of patient Text summaries through an Efficient approach based on direct preference optimization (DPO)". NOTE is trained on MIMIC-III and summarizes a single hospitalization of a patient: patient events are combined sequentially and used to generate a DS for each hospitalization. To demonstrate the practical application of NOTE, we provide a web page-based demonstration. In the future, we aim to make the software available for actual use by clinicians in hospitals. NOTE can be used to generate not only discharge summaries but also various summaries throughout a patient's journey, thereby alleviating the labor-intensive workload of clinicians and increasing efficiency.
Model Description
- Model type: MistralForCausalLM
- Language(s) (NLP): English
- License: CC-BY-NC-SA
- Finetuned from model: mistralai/Mistral-7B-v0.1
Model Sources
Usage
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
model = AutoModelForCausalLM.from_pretrained("jinee/note", load_in_4bit=True, device_map="auto")  # 4-bit loading requires the bitsandbytes package
tokenizer = AutoTokenizer.from_pretrained("jinee/note")
tokenizer.padding_side = 'right'
tokenizer.add_eos_token = True
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_eos_token, tokenizer.add_bos_token  # optional notebook-style check of the EOS/BOS flags
instruction = '''
As a doctor, you need to create a discharge summary based on input data.
Never change the dates or numbers in the input data and use them as is. And please follow the format below for your report.
Also, never make up information that is not in the input data, and write a report only with information that can be identified from the input data.
1. Patient information (SUBJECT_ID, HADM_ID, hospitalization and discharge date, hospitalization period, gender, date of birth, age, allergy)
2. Diagnostic information and past history (if applicable)
3. Surgery or procedure information
4. Significant medication administration during hospitalization and discharge medication history
5. Meaningful lab tests during hospitalization
6. Summary of significant text records/notes
7. Discharge outcomes and treatment plan
8. Overall summary of at least 500 characters in lines including the above contents
'''
torch.cuda.empty_cache()
def generation(model, tokenizer, input_data):
    # Build a text-generation pipeline around the fine-tuned model
    pipe = pipeline('text-generation',
                    model=model,
                    tokenizer=tokenizer,
                    torch_dtype=torch.bfloat16,
                    device_map='auto')
    global instruction
    sequences = pipe(
        f"[INST]{instruction}: {input_data} [/INST]",
        do_sample=True,
        max_new_tokens=1024,
        temperature=0.7,
        top_k=50,
        top_p=0.95,
        early_stopping=True,
        num_return_sequences=1,
    )
    text = sequences[0]['generated_text']
    # Keep only the text generated after the instruction/input prompt
    start_index = text.find('[/INST]')
    if start_index != -1:
        summary_ = text[start_index + len('[/INST]'):]
        return summary_
    else:
        return "'[/INST]' was not found in the generated text."
Dataset
The model was trained on MIMIC-III, a comprehensive and freely accessible de-identified medical database. Access to this database requires completing a number of steps to obtain permission.
Training and Hyper-parameters
List of LoRA config
based on Parameter-Efficient Fine-Tuning (PEFT)
| Parameter | SFT | DPO |
|---|---|---|
| r | 16 | 16 |
| lora alpha | 16 | 16 |
| lora dropout | 0.05 | 0.05 |
| target modules | q, k, v, o, gate | q, k, v, o, gate |
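As a rough sketch, the LoRA settings above correspond to a PEFT LoraConfig along the following lines. The module names q_proj, k_proj, v_proj, o_proj, and gate_proj are an assumed mapping of "q, k, v, o, gate" onto Mistral's projection layers; bias and task_type are not stated in the table.

from peft import LoraConfig

# Sketch of the LoRA configuration listed above (identical values for SFT and DPO)
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],  # assumed mapping of q, k, v, o, gate
    bias="none",            # assumption: not specified in the table
    task_type="CAUSAL_LM",
)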
List of Training arguments
based on Transformer Reinforcement Learning (TRL)
| Parameter | SFT | DPO |
|---|---|---|
| early stopping patience | 3 | 3 |
| early stopping threshold | 0.0005 | 0.0005 |
| train epochs | 20 | 3 |
| per device train batch size | 4 | 1 |
| per device eval batch size | 8 (default) | 1 |
| optimizer | paged adamw 8bit | paged adamw 8bit |
| lr scheduler | cosine | cosine |
| warmup ratio | 0.3 | 0.1 |
| gradient accumulation steps | 2 | 2 |
| evaluation strategy | steps | steps |
| eval steps | 10 | 5 |
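A minimal sketch of how the DPO column above maps onto TRL/Transformers classes is shown below; this is not the authors' training script. output_dir, train_dataset, and eval_dataset are placeholders, lora_config is the sketch from the previous table, and DPO-specific settings such as beta are not listed in the table and are left at their defaults.

from transformers import TrainingArguments, EarlyStoppingCallback
from trl import DPOTrainer

dpo_args = TrainingArguments(
    output_dir="note-dpo",                 # placeholder
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_8bit",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    evaluation_strategy="steps",
    eval_steps=5,
    save_strategy="steps",                 # assumption: needed for load_best_model_at_end
    save_steps=5,
    load_best_model_at_end=True,           # required by EarlyStoppingCallback
    metric_for_best_model="loss",          # assumption: early stopping on eval loss
)

dpo_trainer = DPOTrainer(
    model,                                 # the SFT-tuned policy model
    ref_model=None,                        # with a PEFT adapter, TRL derives the reference model internally
    args=dpo_args,
    train_dataset=train_dataset,           # placeholder: preference pairs (prompt, chosen, rejected)
    eval_dataset=eval_dataset,             # placeholder
    tokenizer=tokenizer,
    peft_config=lora_config,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3,
                                     early_stopping_threshold=0.0005)],
)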
Experimental setup
- Ubuntu 20.04 LTS
- 2 NVIDIA GeForce RTX 3090 GPUs
- Python: 3.8.10
- PyTorch: 2.0.1+cu118
- Transformers: 4.35.2
Limitations
The model's output was limited in character count for comparison with the existing T5 model, but we plan to expand it in future research. Further research on prompt engineering is also needed, since the model can produce different results for the same instructions. Most metrics for evaluating summarization and generation tasks were difficult to apply to our study; we attempted to address this with the ChatGPT-4 Assistant API, and future work will incorporate feedback from clinicians.
Non-commercial use
These models are available exclusively for research purposes and are not intended for commercial use.
INMED DATA
INMED DATA is developing large language models (LLMs) specifically tailored for medical applications. For more information, please visit our website [TBD].