Introduction
This repo contains Gemma-2-9b-Medical, a 9-billion-parameter medical language model. It builds on Gemma-2-9b-base and has been instruction-tuned on a diverse mix of medical and general instructions. During instruction tuning, we apply the three strategies from the paper 'Efficient Continual Pre-training by Mitigating the Stability Gap' to mitigate the stability gap. This improves the model's performance on medical tasks and reduces computational cost.
Usage
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "YiDuo1999/Gemma-2-9b-medical"
device_map = "auto"

# Load the weights in bfloat16 and let Accelerate place them on the available devices
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    use_cache=False,
    device_map=device_map,
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token


def askme(question):
    sys_message = (
        "You are an AI Medical Assistant trained on a vast dataset of health information. "
        "Please be thorough and provide an informative answer. If you don't know the answer "
        "to a specific medical inquiry, advise seeking professional help."
    )
    # Create messages structured for the chat template
    messages = [
        {"role": "system", "content": sys_message},
        {"role": "user", "content": question},
    ]
    # Apply the chat template and append the prompt that opens the assistant's turn
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # Put the inputs on the model's device instead of hard-coding "cuda"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=100, use_cache=True)
    # Decode only the newly generated tokens, which removes the prompt
    prompt_length = inputs["input_ids"].shape[-1]
    answer = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    return answer
```
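Gemma-2's stock chat template raises an error when a conversation starts with a `system` message and requires user/assistant turns to alternate. If this model's tokenizer uses that template, the `apply_chat_template` call in `askme` will fail. A minimal workaround, assuming that is the case, is to fold the system prompt into the first user turn before applying the template. `fold_system_prompt` is a hypothetical helper written for this example, not part of `transformers`.

```python
from typing import Dict, List

Message = Dict[str, str]


def fold_system_prompt(messages: List[Message]) -> List[Message]:
    """Merge a leading system message into the first user message.

    Hypothetical helper for chat templates that reject the "system" role.
    Returns a new list; the input messages are not modified.
    """
    if not messages or messages[0]["role"] != "system":
        return list(messages)
    system_text = messages[0]["content"].strip()
    # Copy the remaining messages so the caller's dicts stay untouched
    rest = [dict(m) for m in messages[1:]]
    if rest and rest[0]["role"] == "user":
        rest[0]["content"] = f"{system_text}\n\n{rest[0]['content']}"
    else:
        # No user turn follows: keep the instructions as a standalone user turn
        rest.insert(0, {"role": "user", "content": system_text})
    return rest
```

Inside `askme`, calling `messages = fold_system_prompt(messages)` just before `tokenizer.apply_chat_template(...)` keeps the system instructions while producing the strict user/assistant alternation that Gemma-style templates expect.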
Evaluation
On medical question-answering benchmarks, the results are as follows:
Model | MMLU-Medical | PubMedQA | MedMCQA | MedQA-4-Option | Avg |
---|---|---|---|---|---|
Mistral-7B-instruct | 55.8 | 17.8 | 40.2 | 41.1 | 37.5 |
Zephyr-7B-instruct-β | 63.3 | 46.0 | 43.0 | 48.5 | 48.7 |
PMC-Llama-7B | 59.7 | 59.2 | 57.6 | 49.2 | 53.6 |
Medalpaca-13B | 55.2 | 50.4 | 21.2 | 20.2 | 36.7 |
AlpaCare-13B | 60.2 | 53.8 | 38.5 | 30.4 | 45.7 |
BioMedGPT-LM 7B | 52.0 | 58.6 | 34.9 | 39.3 | 46.2 |
Me-Llama-13B | - | 70.0 | 44.9 | 42.7 | - |
Llama-3-8B instruct | 82.0 | 74.6 | 57.1 | 60.3 | 68.5 |
JSL-Med-Sft-Llama-3-8B | 83.0 | 75.4 | 57.5 | 74.8 | 72.7 |
GPT-3.5-turbo-1106 | 74.0 | 72.6 | 34.9 | 39.3 | 60.6 |
GPT-4 | 85.5 | 69.2 | 69.5 | 83.9 | 77.0 |
Gemma-2-9b-int | 75.0 | 76.0 | 40.3 | 48.9 | 60.0 |
Gemma-2-9b-Medical | 75.0 | 76.0 | 61.3 | 59.7 | 68.0 |
Llama-3-physician-8B instruct | 80.0 | 76.0 | 80.2 | 60.3 | 74.1 |
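For the two Gemma-2 rows, the Avg column equals the unweighted mean of the four benchmark scores, rounded to one decimal place. Assuming that is how Avg is defined, the sketch below recomputes it from the values in the table; `scores` and `average` are illustrative names, not part of any released evaluation code.

```python
from statistics import mean

# Scores copied from the table (MMLU-Medical, PubMedQA, MedMCQA, MedQA-4-Option)
scores = {
    "Gemma-2-9b-int": [75.0, 76.0, 40.3, 48.9],
    "Gemma-2-9b-Medical": [75.0, 76.0, 61.3, 59.7],
}


def average(row):
    """Unweighted mean of the benchmark scores, rounded to one decimal place."""
    return round(mean(row), 1)


for name, row in scores.items():
    print(f"{name}: {average(row)}")
```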
Citation
```bibtex
@inproceedings{Guo2024EfficientCP,
  title={Efficient Continual Pre-training by Mitigating the Stability Gap},
  author={Yiduo Guo and Jie Fu and Huishuai Zhang and Dongyan Zhao and Yikang Shen},
  year={2024},
  url={https://api.semanticscholar.org/CorpusID:270688100}
}
```