---
model_name: mistralmed
library_name: peft
base_model: mistralai/Mistral-7B-v0.1
license: mit
datasets:
- keivalya/MedQuad-MedicalQnADataset
language:
- en
tags:
- medical
---

# Model Card for Tonic/MistralMed

This is a medicine-focused fine-tune of Mistral-7B, trained on keivalya/MedQuad-MedicalQnADataset.

## Model Details

### Model Description

An attempt to get better at medical Q&A.

- **Developed by:** [Tonic](https://huggingface.co/Tonic)
- **Shared by:** [Tonic](https://huggingface.co/Tonic)
- **Model type:** Mistral fine-tune
- **Language(s) (NLP):** English
- **License:** MIT
- **Finetuned from model:** [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)

### Model Sources

- **Repository:** [Tonic/mistralmed](https://huggingface.co/Tonic/mistralmed)
- **Code:** [github](https://github.com/Josephrp/mistralmed/blob/main/finetuning.py)
- **Demo:** [Tonic/MistralMed_Chat](https://huggingface.co/Tonic/MistralMed_Chat)

## Uses

This model can be used the same way you would normally use Mistral.

### Direct Use

This model is intended to give better answers than the base model in medical question-and-answer scenarios.

### Downstream Use

This model is intended to be fine-tuned further.

### Recommendations

- Do not use as is
- Fine-tune this model further
- For educational purposes only
- Benchmark your model usage
- Evaluate the model before use

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information is needed for further recommendations.

## How to Get Started with the Model

Use the code below to get started with the model.
Demo Space: [pseudolab/MistralMED_Chat](https://huggingface.co/spaces/pseudolab/MistralMED_Chat)

```python
import os
import textwrap

import gradio as gr
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoTokenizer, MistralForCausalLM


def wrap_text(text, width=90):
    """Wrap each line of `text` to at most `width` characters."""
    lines = text.split("\n")
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
    return "\n".join(wrapped_lines)


# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Base model and adapter IDs
base_model_id = "mistralai/Mistral-7B-v0.1"
model_directory = "Tonic/mistralmed"

# Read the Hugging Face token from the environment instead of hard-coding it
hf_token = os.environ.get("HF_TOKEN")

# Instantiate the tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Load the base model, then attach the PEFT adapter
peft_config = PeftConfig.from_pretrained(model_directory, token=hf_token)
peft_model = MistralForCausalLM.from_pretrained(base_model_id, trust_remote_code=True)
peft_model = PeftModel.from_pretrained(peft_model, model_directory, token=hf_token)
peft_model = peft_model.to(device)
peft_model.eval()


def multimodal_prompt(user_input, system_prompt="You are an expert medical analyst:", max_length=512):
    """
    Generates text using a large language model, given a user input and a system prompt.
    Args:
        user_input: The user's input text to generate a response for.
        system_prompt: Optional system prompt.
        max_length: Maximum total length (prompt plus response) in tokens.
    Returns:
        A string containing the generated text.
    """
    # Combine user input and system prompt
    formatted_input = f"[INST]{system_prompt} {user_input}[/INST]"

    # Encode the input text and move it to the model's device
    encodeds = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False)
    model_inputs = encodeds.to(device)

    # Generate a response using the model
    output = peft_model.generate(
        **model_inputs,
        max_length=max_length,
        use_cache=True,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        temperature=0.1,
        do_sample=True,
    )

    # Decode the response
    return tokenizer.decode(output[0], skip_special_tokens=True)


class ChatBot:
    def __init__(self):
        self.history = None  # token IDs of the previous prompts, if any

    def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
        # Combine user input and system prompt
        formatted_input = f"[INST]{system_prompt} {user_input}[/INST]"

        # Encode user input
        user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt").to(device)

        # Concatenate the user input with chat history
        if self.history is not None:
            chat_history_ids = torch.cat([self.history, user_input_ids], dim=-1)
        else:
            chat_history_ids = user_input_ids

        # Generate a response using the PEFT model
        response = peft_model.generate(
            input_ids=chat_history_ids,
            max_length=512,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Update chat history
        self.history = chat_history_ids

        # Decode and return the response
        return tokenizer.decode(response[0], skip_special_tokens=True)


bot = ChatBot()

title = "👋🏻Welcome to Tonic's MistralMed Chat🚀"
description = "You can use this Space to test out the current model (MistralMed) or duplicate this Space and use it for any other model on 🤗HuggingFace. Join me on Discord to build together."
examples = [["What is the proper treatment for buccal herpes?", "Please provide information on the most effective antiviral medications and home remedies for treating buccal herpes."]]

iface = gr.Interface(
    fn=bot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=["text", "text"],  # Take user input and system prompt separately
    outputs="text",
    theme="ParityError/Anime",
)

iface.launch()
```

## Training Details

### Training Data

[MedQuad](https://huggingface.co/datasets/keivalya/MedQuad-MedicalQnADataset/viewer/default/train)

### Training Procedure

```text
Dataset({
    features: ['qtype', 'Question', 'Answer'],
    num_rows: 16407
})
```

#### Preprocessing

```text
MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
```

#### Training Hyperparameters

- **Training regime:**

```python
config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)
```

#### Speeds, Sizes, Times

- trainable params: 21260288 || all
params: 3773331456 || trainable%: 0.5634354746703705
- `TrainOutput(global_step=1000, training_loss=0.47226515007019043, metrics={'train_runtime': 3143.4141, 'train_samples_per_second': 2.545, 'train_steps_per_second': 0.318, 'total_flos': 1.75274075357184e+17, 'train_loss': 0.47226515007019043, 'epoch': 0.49})`

## Environmental Impact

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).

- **Hardware Type:** A100
- **Hours used:** 1
- **Cloud Provider:** Google
- **Compute Region:** East1
- **Carbon Emitted:** 0.09

## Training Results

[1000/1000 52:20, Epoch 0/1]

| Step | Training Loss |
|------|---------------|
| 50   | 0.474200 |
| 100  | 0.523300 |
| 150  | 0.484500 |
| 200  | 0.482800 |
| 250  | 0.498800 |
| 300  | 0.451800 |
| 350  | 0.491800 |
| 400  | 0.488000 |
| 450  | 0.472800 |
| 500  | 0.460400 |
| 550  | 0.464700 |
| 600  | 0.484800 |
| 650  | 0.474600 |
| 700  | 0.477900 |
| 750  | 0.445300 |
| 800  | 0.431300 |
| 850  | 0.461500 |
| 900  | 0.451200 |
| 950  | 0.470800 |
| 1000 | 0.454900 |

### Model Architecture and Objective

```text
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=4096, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              )
              (k_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=1024, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              )
              (v_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=1024, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              )
              (o_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=4096, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              )
              (rotary_emb): MistralRotaryEmbedding()
            )
            (mlp): MistralMLP(
              (gate_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=14336, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)
              )
              (up_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=14336, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)
              )
              (down_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=14336, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=4096, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=14336, out_features=4096, bias=False)
              )
              (act_fn): SiLUActivation()
            )
            (input_layernorm): MistralRMSNorm()
            (post_attention_layernorm): MistralRMSNorm()
          )
        )
        (norm): MistralRMSNorm()
      )
      (lm_head): Linear(
        in_features=4096, out_features=32000, bias=False
        (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
        (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
        (lora_B): ModuleDict((default): Linear(in_features=8, out_features=32000, bias=False))
        (lora_embedding_A): ParameterDict()
        (lora_embedding_B): ParameterDict()
      )
    )
  )
)
```

#### Hardware

A100

## Model Card Authors

[Tonic](https://huggingface.co/Tonic)

## Model Card Contact

[Tonic](https://huggingface.co/Tonic)

## Training procedure

The following `bitsandbytes` quantization config was used during training:

- quant_method: bitsandbytes
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: True
- bnb_4bit_compute_dtype: bfloat16

### Framework versions

- PEFT 0.6.0.dev0
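
### Checking the trainable parameter count

The trainable parameter count reported in this card (21260288) can be cross-checked by hand against the LoRA settings and layer shapes listed above. The sketch below is only an illustration: it assumes each adapted module adds one `lora_A` matrix (`in_features x r`) and one `lora_B` matrix (`r x out_features`) with no biases, which is what the architecture dump suggests. The helper name `lora_params` and the constants are hypothetical, not part of the training code.

```python
# Minimal sketch: recompute the LoRA trainable-parameter count from the
# shapes shown in this card. All names here are illustrative only.

RANK = 8          # r=8 in the LoraConfig listed in this card
NUM_LAYERS = 32   # "(0-31): 32 x MistralDecoderLayer"

# (in_features, out_features) of every adapted projection in one decoder layer
PER_LAYER_SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}
LM_HEAD_SHAPE = (4096, 32000)  # adapted once, outside the decoder layers


def lora_params(in_features, out_features, rank=RANK):
    """Parameters added by one adapter: lora_A (in x r) plus lora_B (r x out)."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o) for i, o in PER_LAYER_SHAPES.values())
trainable = NUM_LAYERS * per_layer + lora_params(*LM_HEAD_SHAPE)

ALL_PARAMS = 3773331456  # "all params" as reported in this card
trainable_pct = 100 * trainable / ALL_PARAMS

print(trainable)      # 21260288
print(trainable_pct)  # roughly 0.5634
```

The result matches the reported `trainable params` value, which suggests that the adapters on the seven projections in each of the 32 layers plus the one on `lm_head` account for all trainable weights.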