library_name: peft
base_model: mistralai/Mistral-7B-v0.1
license: mit
datasets:
  - keivalya/MedQuad-MedicalQnADataset
language:
  - en
metrics:
  - bertscore
tags:
  - medical

Model Card for mistralmed

This is a medicine-focused fine-tune of mistralai/Mistral-7B-v0.1, trained on the keivalya/MedQuad-MedicalQnADataset dataset.

Model Details

Model Description

This model is an attempt to make Mistral better at medical question answering.

  • Developed by: Tonic
  • Shared by: Tonic
  • Model type: Mistral Fine-Tune
  • Language(s) (NLP): English
  • License: MIT
  • Finetuned from model: mistralai/Mistral-7B-v0.1

Uses

This model can be used the same way as the base Mistral model, with this PEFT adapter loaded on top of mistralai/Mistral-7B-v0.1.

Direct Use

The model is intended to perform better than the base model in medical question-answering scenarios.

Downstream Use

This model is intended to be fine-tuned further.

Recommendations

  • Do Not Use As Is
  • Fine Tune This Model Further
  • For Educational Purposes Only
  • Benchmark your model usage
  • Evaluate the model before use

Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.

How to Get Started with the Model

Use the code below to get started with the model.

Tonic/MistralMED_Chat

from transformers import AutoTokenizer, MistralForCausalLM
from peft import PeftModel
import torch
import gradio as gr

# Base model and adapter repository IDs
base_model_id = "mistralai/Mistral-7B-v0.1"
model_directory = "Tonic/mistralmed"

# Load the base model's tokenizer, padding on the left for generation
tokenizer = AutoTokenizer.from_pretrained(base_model_id, padding_side="left")

# Load the base model, then apply the PEFT adapter on top of it.
# If authentication is needed, log in with `huggingface-cli login` or set the
# HF_TOKEN environment variable; never hard-code an access token.
peft_model = MistralForCausalLM.from_pretrained(base_model_id)
peft_model = PeftModel.from_pretrained(peft_model, model_directory)

class ChatBot:
    def __init__(self):
        # Token IDs of the conversation so far (None until the first turn)
        self.history = None

    def predict(self, user_input):
        # Encode the user input, terminated by the end-of-sequence token
        user_input_ids = tokenizer.encode(user_input + tokenizer.eos_token, return_tensors="pt")

        # Append the new input to the chat history, if there is one
        if self.history is not None:
            chat_history_ids = torch.cat([self.history, user_input_ids], dim=-1)
        else:
            chat_history_ids = user_input_ids

        # Generate a response using the PEFT model; max_new_tokens bounds the
        # reply length regardless of how long the history has grown
        with torch.no_grad():
            response = peft_model.generate(input_ids=chat_history_ids, max_new_tokens=512, pad_token_id=tokenizer.eos_token_id)

        # Keep prompt and reply so the next turn sees the whole conversation
        self.history = response

        # Decode only the newly generated tokens and return them
        response_text = tokenizer.decode(response[0, chat_history_ids.shape[-1]:], skip_special_tokens=True)
        return response_text

bot = ChatBot()

title = "👋🏻Welcome to Tonic's MistralMed Chat🚀"
description = "You can use this Space to test out the current model (MistralMed) or duplicate this Space and use it for any other model on 🤗HuggingFace. Join me on Discord to build together."
examples = [["What is the boiling point of nitrogen"]]

iface = gr.Interface(
    fn=bot.predict,
    title=title,
    description=description,
    examples=examples,
    inputs="text",
    outputs="text",
    theme="ParityError/Anime"
)

iface.launch()

Training Details

Training Data

MedQuad (keivalya/MedQuad-MedicalQnADataset)

Training Procedure

Dataset({ features: ['qtype', 'Question', 'Answer'], num_rows: 16407 })

Preprocessing

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )
    )
    (norm): MistralRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

Training Hyperparameters

  • Training regime:

        config = LoraConfig(
            r=8,
            lora_alpha=16,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
                "lm_head",
            ],
            bias="none",
            lora_dropout=0.05,  # Conventional
            task_type="CAUSAL_LM",
        )
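The trainable parameter count reported in this card can be reproduced from the LoRA configuration and the layer shapes in the architecture printout. The sketch below assumes that each adapted linear layer adds r × (in_features + out_features) weights (one lora_A and one lora_B matrix, no bias); the variable names are illustrative only.

```python
# LoRA rank from the configuration above
r = 8

# Layer dimensions read from the architecture printout in this card
hidden, kv, mlp, vocab, n_layers = 4096, 1024, 14336, 32000, 32

# (in_features, out_features) of every adapted projection in one decoder layer
per_layer_shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv),
    "v_proj": (hidden, kv),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, mlp),
    "up_proj": (hidden, mlp),
    "down_proj": (mlp, hidden),
}


def lora_params(in_features, out_features, rank=r):
    """Weights in lora_A (rank x in) plus lora_B (out x rank)."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o) for i, o in per_layer_shapes.values())
total = n_layers * per_layer + lora_params(hidden, vocab)  # lm_head is adapted once

# "all params" as reported in this card
all_params = 3773331456
trainable_pct = 100 * total / all_params

print(per_layer, total)  # 655360 21260288
```

The result matches the reported 21260288 trainable parameters. The reported total of about 3.77 billion is lower than the nominal 7 billion, most likely because 4-bit quantized weights are stored packed and counted by their packed size.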

Speeds, Sizes, Times

  • trainable params: 21260288 || all params: 3773331456 || trainable%: 0.5634354746703705
  • TrainOutput(global_step=1000, training_loss=0.47226515007019043, metrics={'train_runtime': 3143.4141, 'train_samples_per_second': 2.545, 'train_steps_per_second': 0.318, 'total_flos': 1.75274075357184e+17, 'train_loss': 0.47226515007019043, 'epoch': 0.49})
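The throughput figures in the TrainOutput are consistent with each other and with the dataset size. The sketch below derives the number of samples per optimizer step, which the card does not state directly, so treat that value as an inference rather than a documented setting.

```python
# Values copied from the TrainOutput and dataset summary in this card
global_step = 1000
train_runtime_s = 3143.4141
samples_per_second = 2.545
steps_per_second = 0.318
num_rows = 16407

# Inferred, not stated in the card: samples consumed per optimizer step
samples_per_step = round(samples_per_second / steps_per_second)

# Total samples seen and the fraction of one epoch they cover
samples_seen = global_step * samples_per_step
epoch_fraction = samples_seen / num_rows

# Runtime expressed as minutes:seconds
minutes, seconds = divmod(round(train_runtime_s), 60)

print(samples_per_step, samples_seen, round(epoch_fraction, 2), f"{minutes}:{seconds:02d}")
# 8 8000 0.49 52:23
```

Under this reading, training covered roughly half of the dataset once, which agrees with the reported epoch value of 0.49.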

Environmental Impact

Carbon emissions can be estimated using the Machine Learning Impact calculator presented in Lacoste et al. (2019).

  • Hardware Type: A100
  • Hours used: 1
  • Cloud Provider: Google
  • Compute Region: East1
  • Carbon Emitted: 0.09

Training Results

[1000/1000 52:20, Epoch 0/1]

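| Step | Training Loss |
|------|---------------|
| 50   | 0.474200 |
| 100  | 0.523300 |
| 150  | 0.484500 |
| 200  | 0.482800 |
| 250  | 0.498800 |
| 300  | 0.451800 |
| 350  | 0.491800 |
| 400  | 0.488000 |
| 450  | 0.472800 |
| 500  | 0.460400 |
| 550  | 0.464700 |
| 600  | 0.484800 |
| 650  | 0.474600 |
| 700  | 0.477900 |
| 750  | 0.445300 |
| 800  | 0.431300 |
| 850  | 0.461500 |
| 900  | 0.451200 |
| 950  | 0.470800 |
| 1000 | 0.454900 |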
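As a quick sanity check, the mean of the logged losses can be compared with the overall training_loss in the TrainOutput. This is a minimal sketch that assumes each logged value is the average loss over the preceding 50 steps, which is how the numbers appear to be reported.

```python
# Training losses logged every 50 steps, copied from the results above
logged_losses = [
    0.4742, 0.5233, 0.4845, 0.4828, 0.4988,
    0.4518, 0.4918, 0.4880, 0.4728, 0.4604,
    0.4647, 0.4848, 0.4746, 0.4779, 0.4453,
    0.4313, 0.4615, 0.4512, 0.4708, 0.4549,
]

# Overall training loss reported in the TrainOutput
reported_training_loss = 0.47226515007019043

mean_logged = sum(logged_losses) / len(logged_losses)
print(round(mean_logged, 4), round(reported_training_loss, 4))
```

The two values agree to about four decimal places. The loss moves only slightly over the 1000 steps, which is one more reason to evaluate the model before relying on it.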

Model Architecture and Objective

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): MistralForCausalLM(
      (model): MistralModel(
        (embed_tokens): Embedding(32000, 4096)
        (layers): ModuleList(
          (0-31): 32 x MistralDecoderLayer(
            (self_attn): MistralAttention(
              (q_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=4096, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              )
              (k_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=1024, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              )
              (v_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=1024, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
              )
              (o_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=4096, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
              )
              (rotary_emb): MistralRotaryEmbedding()
            )
            (mlp): MistralMLP(
              (gate_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=14336, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)
              )
              (up_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=14336, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=4096, out_features=14336, bias=False)
              )
              (down_proj): Linear4bit(
                (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
                (lora_A): ModuleDict((default): Linear(in_features=14336, out_features=8, bias=False))
                (lora_B): ModuleDict((default): Linear(in_features=8, out_features=4096, bias=False))
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (base_layer): Linear4bit(in_features=14336, out_features=4096, bias=False)
              )
              (act_fn): SiLUActivation()
            )
            (input_layernorm): MistralRMSNorm()
            (post_attention_layernorm): MistralRMSNorm()
          )
        )
        (norm): MistralRMSNorm()
      )
      (lm_head): Linear(
        in_features=4096, out_features=32000, bias=False
        (lora_dropout): ModuleDict((default): Dropout(p=0.05, inplace=False))
        (lora_A): ModuleDict((default): Linear(in_features=4096, out_features=8, bias=False))
        (lora_B): ModuleDict((default): Linear(in_features=8, out_features=32000, bias=False))
        (lora_embedding_A): ParameterDict()
        (lora_embedding_B): ParameterDict()
      )
    )
  )
)
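Each adapted layer in the printout keeps its frozen base weights and adds two small trainable matrices, lora_A and lora_B. In the standard LoRA formulation, and assuming the default scaling of lora_alpha / r, the output of such a layer for an input x can be written as:

```latex
h = W_0 x + \frac{\alpha}{r} \, B A x,
\qquad A \in \mathbb{R}^{r \times d_{\text{in}}},
\quad B \in \mathbb{R}^{d_{\text{out}} \times r}
```

With r = 8 and lora_alpha = 16 as configured for this model, the scaling factor is 2. For q_proj, for example, A has shape 8 × 4096 and B has shape 4096 × 8. During training, dropout with p = 0.05 is applied to x on the LoRA branch only.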

Hardware

A100

Model Card Authors

Tonic

Model Card Contact

Tonic

Training procedure

The following bitsandbytes quantization config was used during training:

  • quant_method: bitsandbytes
  • load_in_8bit: False
  • load_in_4bit: True
  • llm_int8_threshold: 6.0
  • llm_int8_skip_modules: None
  • llm_int8_enable_fp32_cpu_offload: False
  • llm_int8_has_fp16_weight: False
  • bnb_4bit_quant_type: nf4
  • bnb_4bit_use_double_quant: True
  • bnb_4bit_compute_dtype: bfloat16
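The settings listed above map onto the BitsAndBytesConfig class in transformers. The following is a minimal sketch of an equivalent configuration, not the exact training script; only the non-default 4-bit options are spelled out, and the variable name bnb_config is illustrative.

```python
import torch
from transformers import BitsAndBytesConfig

# 4-bit NF4 quantization with double quantization and bfloat16 compute,
# mirroring the values listed in this card
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
```

Such an object can be passed as quantization_config to from_pretrained when loading the base model, which requires the bitsandbytes package and a supported GPU.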

Framework versions

  • PEFT 0.6.0.dev0