metadata
model_name: mistralmed
library_name: peft
base_model: mistralai/Mistral-7B-v0.1
license: mit
datasets:
  - keivalya/MedQuad-MedicalQnADataset
language:
  - en
tags:
  - medical

Model Card for Tonic/MistralMed

This is a medicine-focused fine-tune of Mistral-7B-v0.1, trained on the keivalya/MedQuad-MedicalQnADataset dataset.

Model Details

Model Description

The goal of this model is to improve on the base Mistral model at medical question answering (Q&A).

Model Sources

Uses

This model can be used in the same way as the base Mistral model, once the PEFT adapter is loaded on top of it.

Direct Use

This model is intended to perform better than the base Mistral model in medical question-and-answer scenarios.

Downstream Use

This model is intended to be fine-tuned further.

Recommendations

  • Do Not Use As Is
  • Fine Tune This Model Further
  • For Educational Purposes Only
  • Benchmark your model usage
  • Evaluate the model before use

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

How to Get Started with the Model

Use the code below to get started with the model.

pseudolab/MistralMED_Chat

from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, MistralForCausalLM
from peft import PeftModel, PeftConfig
import torch
import gradio as gr
import textwrap  # wrap_text below calls textwrap.fill


# Functions to Wrap the Prompt Correctly
def wrap_text(text, width=90):
    lines = text.split('\n')
    wrapped_lines = [textwrap.fill(line, width=width) for line in lines]
    wrapped_text = '\n'.join(wrapped_lines)
    return wrapped_text

def multimodal_prompt(user_input, system_prompt="You are an expert medical analyst:", max_length=512):
    """
    Generates text using the fine-tuned model, given a user input and a system prompt.
    Args:
        user_input: The user's input text to generate a response for.
        system_prompt: Optional system prompt.
        max_length: Maximum total length (prompt plus response) in tokens.
            The default of 512 mirrors the value used in ChatBot.predict.
    Returns:
        A string containing the generated text.
    """
    # Combine user input and system prompt
    formatted_input = f"<s>[INST]{system_prompt} {user_input}[/INST]"

    # Encode the input text; the prompt already contains <s>, so no special tokens are added
    encodeds = tokenizer(formatted_input, return_tensors="pt", add_special_tokens=False)
    # Move the inputs to whichever device the model is loaded on
    model_inputs = encodeds.to(peft_model.device)

    # Generate a response using the PEFT model
    with torch.no_grad():
        output = peft_model.generate(
            **model_inputs,
            max_length=max_length,
            use_cache=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            temperature=0.1,
            do_sample=True,
        )

    # Decode the response (the decoded text includes the prompt)
    response_text = tokenizer.decode(output[0], skip_special_tokens=True)

    return response_text

# Define the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Use the base model's ID
base_model_id = "mistralai/Mistral-7B-v0.1"
model_directory = "Tonic/mistralmed"

# Instantiate the Tokenizer
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1", trust_remote_code=True, padding_side="left")
# tokenizer = AutoTokenizer.from_pretrained("Tonic/mistralmed", trust_remote_code=True, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

# Specify the configuration class for the model
#model_config = AutoConfig.from_pretrained(base_model_id)

# Load the PEFT model with the specified configuration
#peft_model = AutoModelForCausalLM.from_pretrained(base_model_id, config=model_config)

# Load the base model, then attach the PEFT (LoRA) adapter to it.
# If the repository requires authentication, log in with `huggingface-cli login`
# rather than hard-coding an access token in the script.
peft_config = PeftConfig.from_pretrained(model_directory)
peft_model = MistralForCausalLM.from_pretrained(base_model_id)
peft_model = PeftModel.from_pretrained(peft_model, model_directory)
peft_model.eval()

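class ChatBot:
    def __init__(self):
        # Token IDs of the prompts seen so far; None before the first turn
        self.history = None

    def predict(self, user_input, system_prompt="You are an expert medical analyst:"):
        # Combine user input and system prompt
        formatted_input = f"<s>[INST]{system_prompt} {user_input}[/INST]"

        # Encode user input; the prompt already contains <s>, so no special tokens are added
        user_input_ids = tokenizer.encode(formatted_input, return_tensors="pt", add_special_tokens=False)

        # Concatenate the user input with chat history
        if self.history is not None:
            chat_history_ids = torch.cat([self.history, user_input_ids], dim=-1)
        else:
            chat_history_ids = user_input_ids

        # Generate a response using the PEFT model, on whichever device it is loaded.
        # max_new_tokens bounds only the reply, so a growing history cannot exhaust the budget.
        with torch.no_grad():
            response = peft_model.generate(
                input_ids=chat_history_ids.to(peft_model.device),
                max_new_tokens=512,
                pad_token_id=tokenizer.eos_token_id,
            )

        # Update chat history (prompts only; generated replies are not stored)
        self.history = chat_history_ids

        # Decode and return the response (the decoded text includes the prompt)
        response_text = tokenizer.decode(response[0], skip_special_tokens=True)
        return response_text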

bot = ChatBot()

title = "👋🏻Welcome to Tonic's MistralMed Chat🚀"
description = "You can use this Space to test out the current model (MistralMed) or duplicate this Space and use it for any other model on 🤗HuggingFace. Join me on Discord to build together."
examples = [["What is the proper treatment for buccal herpes?", "Please provide information on the most effective antiviral medications and home remedies for treating buccal herpes."]]

iface = gr.Interface(
    fn=bot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs=["text", "text"],  # Take user input and system prompt separately
    outputs="text",
    theme="ParityError/Anime"
)

iface.launch()
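The code above builds every prompt as `<s>[INST]{system_prompt} {user_input}[/INST]` and returns decoded text that still contains the prompt. The following is a minimal sketch of two helpers for that format. The names `build_prompt` and `extract_answer` are hypothetical (they are not part of the original code), and the second helper assumes the decoded text keeps the literal `[/INST]` marker.

```python
def build_prompt(user_input, system_prompt="You are an expert medical analyst:"):
    """Format a single-turn prompt in the same layout the code above uses."""
    return f"<s>[INST]{system_prompt} {user_input}[/INST]"


def extract_answer(decoded_text):
    """Return only the text after the last [/INST] marker.

    Assumption: the decoded output still contains the literal marker.
    If it does not, the text is returned unchanged (apart from stripping).
    """
    _, found, answer = decoded_text.rpartition("[/INST]")
    return answer.strip() if found else decoded_text.strip()


if __name__ == "__main__":
    print(build_prompt("What causes anemia?"))
    print(extract_answer("[INST]You are an expert medical analyst: What causes anemia?[/INST] Example reply."))
```

Passing the output of `multimodal_prompt` or `ChatBot.predict` through `extract_answer` would show only the model's reply in the interface, if that is the desired behavior.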

Training Details

Training Data

MedQuad (keivalya/MedQuad-MedicalQnADataset)

Training Procedure

Dataset({
    features: ['qtype', 'Question', 'Answer'],
    num_rows: 16407
})

Preprocessing [optional]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

Training Hyperparameters

  • Training regime:
from peft import LoraConfig

config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
    bias="none",
    lora_dropout=0.05,  # Conventional
    task_type="CAUSAL_LM",
)
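As a cross-check, the trainable parameter count reported in this card can be reproduced from the LoRA rank and the layer shapes in the printed architecture. The sketch below assumes each targeted linear layer gets one `lora_A` matrix (`in_features x r`) and one `lora_B` matrix (`r x out_features`) with no bias, as the printed modules suggest. The helper name `lora_param_count` is hypothetical.

```python
def lora_param_count(in_features: int, out_features: int, r: int) -> int:
    """Parameters added by one LoRA adapter: A (in x r) plus B (r x out)."""
    return r * (in_features + out_features)


R = 8            # r from the LoraConfig
NUM_LAYERS = 32  # decoder layers in the printed architecture

# (in_features, out_features) of each targeted projection in one decoder layer
PER_LAYER_SHAPES = {
    "q_proj": (4096, 4096),
    "k_proj": (4096, 1024),
    "v_proj": (4096, 1024),
    "o_proj": (4096, 4096),
    "gate_proj": (4096, 14336),
    "up_proj": (4096, 14336),
    "down_proj": (14336, 4096),
}
LM_HEAD_SHAPE = (4096, 32000)

per_layer = sum(lora_param_count(i, o, R) for i, o in PER_LAYER_SHAPES.values())
trainable = NUM_LAYERS * per_layer + lora_param_count(*LM_HEAD_SHAPE, R)

ALL_PARAMS = 3773331456  # "all params" as reported in this card
print(trainable)                                # 21260288
print(f"{100 * trainable / ALL_PARAMS:.4f}%")   # 0.5634%
```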

Speeds, Sizes, Times [optional]

  • trainable params: 21260288 || all params: 3773331456 || trainable%: 0.5634354746703705
  • TrainOutput(global_step=1000, training_loss=0.47226515007019043, metrics={'train_runtime': 3143.4141, 'train_samples_per_second': 2.545, 'train_steps_per_second': 0.318, 'total_flos': 1.75274075357184e+17, 'train_loss': 0.47226515007019043, 'epoch': 0.49})
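The throughput figures in the `TrainOutput` above also imply an effective batch size and an epoch fraction. The sketch below derives both, assuming all 16407 rows of the dataset were used for training (the card does not state whether a split was held out).

```python
# Values copied from the TrainOutput and the Dataset summary in this card
train_runtime = 3143.4141      # seconds
samples_per_second = 2.545
steps_per_second = 0.318
num_rows = 16407               # assumption: every row was used for training

# Samples processed per optimizer step (effective batch size)
samples_per_step = samples_per_second / steps_per_second

# Total samples seen, and the fraction of one epoch that represents
samples_seen = samples_per_second * train_runtime
epoch_fraction = samples_seen / num_rows

print(round(samples_per_step))      # 8
print(round(epoch_fraction, 2))     # 0.49
```

Under that assumption, the result agrees with the reported `'epoch': 0.49`.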

Environmental Impact

Carbon emissions can be estimated using the Machine Learning Impact calculator presented in Lacoste et al. (2019).

  • Hardware Type: A100
  • Hours used: 1
  • Cloud Provider: Google
  • Compute Region: East1
  • Carbon Emitted: 0.09

Training Results

[1000/1000 52:20, Epoch 0/1]

Step   Training Loss
50     0.474200
100    0.523300
150    0.484500
200    0.482800
250    0.498800
300    0.451800
350    0.491800
400    0.488000
450    0.472800
500    0.460400
550    0.464700
600    0.484800
650    0.474600
700    0.477900
750    0.445300
800    0.431300
850    0.461500
900    0.451200
950    0.470800
1000   0.454900
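The logged values can be summarized with a short script. This sketch averages the twenty logged losses, compares the result with the `training_loss` reported in the `TrainOutput`, and compares the first and last five log points. It assumes each logged value is the mean loss over the preceding 50 steps.

```python
# Training loss logged every 50 steps, copied from the table above
losses = [
    0.4742, 0.5233, 0.4845, 0.4828, 0.4988,
    0.4518, 0.4918, 0.4880, 0.4728, 0.4604,
    0.4647, 0.4848, 0.4746, 0.4779, 0.4453,
    0.4313, 0.4615, 0.4512, 0.4708, 0.4549,
]

reported_training_loss = 0.47226515007019043  # from the TrainOutput in this card

mean_loss = sum(losses) / len(losses)
first_five = sum(losses[:5]) / 5
last_five = sum(losses[-5:]) / 5

print(f"mean of logged losses: {mean_loss:.4f}")   # 0.4723
print(f"first five log points: {first_five:.4f}")  # 0.4927
print(f"last five log points:  {last_five:.4f}")   # 0.4539
```

The mean of the logged values matches the reported training loss to four decimal places, and the loss drops by roughly 0.04 between the first and last 250 steps.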

Model Architecture and Objective

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              )
              (k_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              )
              (v_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=1024, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              )
              (o_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              )
              (rotary_emb): MistralRotaryEmbedding()
            )
            (mlp): MistralMLP(
              (gate_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)
              )
              (up_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=14336, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)
              )
              (down_proj): Linear4bit(
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=14336, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=14336, out_features=4096, bias=False)
              )
              (act_fn): SiLUActivation()
            )
            (input_layernorm): MistralRMSNorm()
            (post_attention_layernorm): MistralRMSNorm()
          )
        )
        (norm): MistralRMSNorm()
      )
      (lm_head): Linear(
        in_features=4096, out_features=32000, bias=False
        (lora_dropout): ModuleDict(
          (default): Dropout(p=0.05, inplace=False)
        )
        (lora_A): ModuleDict(
          (default): Linear(in_features=4096, out_features=8, bias=False)
        )
        (lora_B): ModuleDict(
          (default): Linear(in_features=8, out_features=32000, bias=False)
        )
        (lora_embedding_A): ParameterDict()
        (lora_embedding_B): ParameterDict()
      )
    )
  )
)

Hardware

A100

Model Card Authors [optional]

Tonic

Model Card Contact

Tonic

Training procedure

The following bitsandbytes quantization config was used during training:

  • quant_method: bitsandbytes
  • load_in_8bit: False
  • load_in_4bit: True
  • llm_int8_threshold: 6.0
  • llm_int8_skip_modules: None
  • llm_int8_enable_fp32_cpu_offload: False
  • llm_int8_has_fp16_weight: False
  • bnb_4bit_quant_type: nf4
  • bnb_4bit_use_double_quant: True
  • bnb_4bit_compute_dtype: bfloat16
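The 4-bit settings listed above can be expressed as a `BitsAndBytesConfig` when loading the base model. The following is a sketch rather than the author's training script: the function name `load_quantized_base` is hypothetical, and it assumes a CUDA GPU with the `bitsandbytes` and `accelerate` packages installed. The `llm_int8_*` entries are omitted because the values listed above appear to be the library defaults.

```python
# 4-bit settings copied from the list above
QUANT_KWARGS = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": "bfloat16",
}


def load_quantized_base(model_id="mistralai/Mistral-7B-v0.1"):
    """Load the base model with 4-bit NF4 quantization (hypothetical helper)."""
    # Imported lazily so the settings above can be inspected without a GPU stack
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    kwargs = dict(QUANT_KWARGS)
    # Convert the dtype name into the torch dtype object
    kwargs["bnb_4bit_compute_dtype"] = getattr(torch, kwargs["bnb_4bit_compute_dtype"])
    bnb_config = BitsAndBytesConfig(**kwargs)

    return AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=bnb_config,
        device_map="auto",  # requires the accelerate package
    )
```

Loading the base model this way before attaching the adapter reduces memory use compared with loading full-precision weights.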

Framework versions

  • PEFT 0.6.0.dev0