
GPT-J-6B-MedDoctor 👨‍⚕️

This chatbot model was built by Parameter-Efficient Fine-Tuning (PEFT) of GPT-J-6B on all 16.3k rows of Medical Data. Fine-tuning ran on a single A100 (40 GB) GPU for roughly 1 day 7 hours.

Example prompts and responses

Example 1:

User:

How can I cope with the stress and anxiety I've been experiencing lately?

GPT-J-6B-MedDoctor-PEFT:

  1. Talk to friends and family to get support.
  2. Do things you enjoy, such as sports, hobbies, or volunteering.
  3. Learn relaxation techniques for when you are stressed.
  4. Take short walks or get outside as often as you can to get fresh air.

Example 2:

User:

I've been struggling with insomnia. What are some strategies to improve my sleep quality?

GPT-J-6B-MedDoctor-PEFT:

Learn about sleep hygiene. You can read about this topic at the Sleep Education and Therapy Page. Sleep hygiene can help you

  1. get the right amount of sleep at the right time
  2. stay relaxed and comfortable throughout the night
  3. fall asleep and stay asleep quickly
  4. remember to call your doctor if you suddenly feel better and you have a fever
  5. know when your next doctor visit is
  6. don't take naps at night or during the day if you take medication that affects your sleep
  7. don't drink caffeine or alcohol before you go to sleep or if you feel tired the next day
  8. don't smoke
  9. stop using caffeine or alcohol, or drink decaffeinated coffee

Example 3:

User:

I have a chronic illness, and it's impacting my emotional well-being. How can I manage the emotional toll it's taking on me?

GPT-J-6B-MedDoctor-PEFT:

Think about how often you experience feelings such as sadness, worry, anger, or guilt. For example, do you feel frustrated or embarrassed when you cannot make others happy? Do you experience frequent feelings of sadness, despair, and anger? If so, take time to talk to a health professional about how to develop a better understanding of how and why you feel these emotions. Talk to a mental health professional about how to change destructive thinking patterns, Virgin Island Blueprints for Healthy Living guides for both men and women. A group session with other members of your support network may also help.

Example 4:

User:

what are the preventive measures of diabetes?

GPT-J-6B-MedDoctor-PEFT:

Get physical activity. Make changes in your lifestyle such as a eating a healthly diet, quitting smoke, and getting regular physical activity. Getting physical activity will help you lose weight and keep your blood glucose levels down. Talk with your health care provider before you start new exercise program. You may need to adjust your medication or try a different medicine.


Model Description

GPT-J 6B is a transformer model trained using Ben Wang's Mesh Transformer JAX. "GPT-J" refers to the class of model, while "6B" represents the number of trainable parameters.
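The "6B" can be cross-checked against the hyperparameters in the table below. The following sketch recomputes the per-head dimension and an approximate parameter count with the usual transformer estimate: 4·d_model² attention weights (Q, K, V and output projections) plus 2·d_model·d_ff feed-forward weights per layer, plus an input embedding and a separate output projection of n_vocab × d_model each (assuming, as in GPT-J, that the two are not tied). Biases and LayerNorm weights are ignored, so the result is an estimate, not an exact count.

```python
# Hyperparameters copied from the GPT-J-6B table.
N_LAYERS = 28
D_MODEL = 4096
D_FF = 16384
N_HEADS = 16
N_VOCAB = 50400                  # padded embedding size, not the 50257 tokens in use
REPORTED_PARAMS = 6_053_381_344


def head_dim(d_model: int, n_heads: int) -> int:
    """Each attention head receives an equal slice of the model dimension."""
    if d_model % n_heads:
        raise ValueError("d_model must be divisible by n_heads")
    return d_model // n_heads


def approx_params(n_layers: int, d_model: int, d_ff: int, n_vocab: int) -> int:
    """Rough weight count that ignores biases and LayerNorm parameters."""
    attention = 4 * d_model * d_model      # Q, K, V and output projections
    feed_forward = 2 * d_model * d_ff      # up- and down-projection matrices
    embeddings = 2 * n_vocab * d_model     # input embedding + untied output head
    return n_layers * (attention + feed_forward) + embeddings


estimate = approx_params(N_LAYERS, D_MODEL, D_FF, N_VOCAB)
print("d_head:", head_dim(D_MODEL, N_HEADS))
print(f"estimate: {estimate:,}  vs reported: {REPORTED_PARAMS:,}")
print(f"relative gap: {abs(estimate - REPORTED_PARAMS) / REPORTED_PARAMS:.3%}")
```

The estimate comes out at roughly 6.05 billion, within about 0.06% of the reported 6,053,381,344, which is consistent with "6B" being a rounded parameter count.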

Hyperparameter        Value
n_parameters          6053381344
n_layers              28*
d_model               4096
d_ff                  16384
n_heads               16
d_head                256
n_ctx                 2048
n_vocab               50257/50400† (same tokenizer as GPT-2/3)
Positional Encoding   Rotary Position Embedding (RoPE)
RoPE Dimensions       64

* Each layer consists of one feedforward block and one self-attention block.
† Although the embedding matrix has a size of 50400, only 50257 entries are used by the GPT-2 tokenizer.
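The RoPE rows mean that rotary position embeddings act on only 64 of the 256 dimensions in each attention head. The pure-Python sketch below illustrates this for a single head vector: adjacent pairs among the first 64 entries are rotated by position-dependent angles, and the remaining 192 entries pass through unchanged. The interleaved pairing and the frequency base of 10000 are assumptions based on the standard RoPE formulation; this is an illustration of the idea, not the model's actual implementation.

```python
import math

HEAD_DIM = 256    # d_head from the table
ROTARY_DIM = 64   # "RoPE Dimensions" from the table
BASE = 10000.0    # assumed frequency base (the standard RoPE value)


def apply_partial_rope(x, position, rotary_dim=ROTARY_DIM, base=BASE):
    """Rotate the first `rotary_dim` entries of one head vector, two at a time.

    Pair (2i, 2i + 1) is rotated by position * base**(-2i / rotary_dim) radians;
    entries from index `rotary_dim` onward are returned unchanged.
    """
    if rotary_dim % 2 or rotary_dim > len(x):
        raise ValueError("rotary_dim must be even and no larger than the head size")
    out = list(x)
    for i in range(rotary_dim // 2):
        theta = position * base ** (-2 * i / rotary_dim)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


# Toy query/key vectors for one head.
q = [math.sin(i) for i in range(HEAD_DIM)]
k = [math.cos(i) for i in range(HEAD_DIM)]

# Both pairs are 4 positions apart, so the attention scores should agree.
print(dot(apply_partial_rope(q, 7), apply_partial_rope(k, 3)))
print(dot(apply_partial_rope(q, 14), apply_partial_rope(k, 10)))
```

The two printed scores agree up to floating-point rounding because only the offset between query and key positions matters, which is what makes RoPE a relative position encoding.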

Finetuning Description

This model was fine-tuned on a single A100 (40 GB) GPU for about 1 day 7 hours.

Run: July 23, 2023

  • args: {'lr': 0.001, 'num_epochs': 10, 'seed': 42}
  • log_of_epoch_01:{'eval_loss': 0.9936667084693909, 'eval_runtime': 450.8767, 'eval_samples_per_second': 7.246, 'eval_steps_per_second': 0.455, 'epoch': 1.0}
  • log_of_epoch_02:{'eval_loss': 0.9738781452178955, 'eval_runtime': 447.3755, 'eval_samples_per_second': 7.303, 'eval_steps_per_second': 0.458, 'epoch': 2.0}
  • log_of_epoch_03:{'eval_loss': 0.9600604176521301, 'eval_runtime': 441.2023, 'eval_samples_per_second': 7.405, 'eval_steps_per_second': 0.465, 'epoch': 3.0}
  • log_of_epoch_04:{'eval_loss': 0.9634631872177124, 'eval_runtime': 441.53, 'eval_samples_per_second': 7.399, 'eval_steps_per_second': 0.464, 'epoch': 4.0}
  • log_of_epoch_05:{'eval_loss': 0.961345374584198, 'eval_runtime': 441.3189, 'eval_samples_per_second': 7.403, 'eval_steps_per_second': 0.465, 'epoch': 5.0}
  • log_of_epoch_06:{'eval_loss': 0.9655225872993469, 'eval_runtime': 441.9449, 'eval_samples_per_second': 7.392, 'eval_steps_per_second': 0.464, 'epoch': 6.0}
  • log_of_epoch_07:{'eval_loss': 0.9740663766860962, 'eval_runtime': 441.7603, 'eval_samples_per_second': 7.395, 'eval_steps_per_second': 0.464, 'epoch': 7.0}
  • log_of_epoch_08:{'eval_loss': 0.9907786846160889, 'eval_runtime': 441.6064, 'eval_samples_per_second': 7.398, 'eval_steps_per_second': 0.464, 'epoch': 8.0}
  • log_of_epoch_09:{'eval_loss': 1.0046937465667725, 'eval_runtime': 441.9242, 'eval_samples_per_second': 7.393, 'eval_steps_per_second': 0.464, 'epoch': 9.0}
  • log_of_epoch_10:{'train_runtime': 118063.0495, 'train_samples_per_second': 1.107, 'train_steps_per_second': 0.069, 'train_loss': 0.7715376593637642, 'epoch': 10.0}
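Pulling the numbers out of the log makes the trend easier to see. The script below is an illustrative analysis, not part of the original training code: it finds the epoch with the lowest eval_loss and turns the reported throughputs into rough example counts. The lowest eval_loss (about 0.960) occurs at epoch 3, and the loss rises steadily after epoch 5, which points to overfitting in the later epochs; the card does not say which epoch's adapter weights were published.

```python
# eval_loss per epoch, copied from the log above (epoch 10 logged only training stats).
EVAL_LOSS = {
    1: 0.9936667084693909,
    2: 0.9738781452178955,
    3: 0.9600604176521301,
    4: 0.9634631872177124,
    5: 0.961345374584198,
    6: 0.9655225872993469,
    7: 0.9740663766860962,
    8: 0.9907786846160889,
    9: 1.0046937465667725,
}
NUM_EPOCHS = 10
TRAIN_RUNTIME_S = 118063.0495
TRAIN_SAMPLES_PER_S = 1.107
EVAL_RUNTIME_S = 450.8767      # epoch 1
EVAL_SAMPLES_PER_S = 7.246     # epoch 1

best_epoch = min(EVAL_LOSS, key=EVAL_LOSS.get)
train_hours = TRAIN_RUNTIME_S / 3600

# Throughput x runtime approximates the number of examples processed.
# The logged rates are rounded, so these counts are rough.
train_examples = TRAIN_SAMPLES_PER_S * TRAIN_RUNTIME_S / NUM_EPOCHS
eval_examples = EVAL_SAMPLES_PER_S * EVAL_RUNTIME_S

print(f"lowest eval_loss at epoch {best_epoch}: {EVAL_LOSS[best_epoch]:.4f}")
print(f"train_runtime: {train_hours:.1f} h")
print(f"approx. examples: {train_examples:.0f} train + {eval_examples:.0f} eval")
```

The implied split (about 13,070 training and 3,267 evaluation examples) adds up to roughly 16.3k rows, consistent with the dataset size stated at the top of this card.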

Pretraining Data

For more details on the pretraining process, see GPT-J-6B.

The fine-tuning data was tokenized using the GPT-J-6B tokenizer.

Training procedure

The following bitsandbytes quantization config was used during training:

  • load_in_8bit: True
  • load_in_4bit: False
  • llm_int8_threshold: 6.0
  • llm_int8_skip_modules: None
  • llm_int8_enable_fp32_cpu_offload: False
  • llm_int8_has_fp16_weight: False
  • bnb_4bit_quant_type: fp4
  • bnb_4bit_use_double_quant: False
  • bnb_4bit_compute_dtype: float32
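For reference, the sketch below maps these logged values onto keyword arguments of transformers' BitsAndBytesConfig and loads the base model in 8-bit with them. It assumes a transformers release that provides BitsAndBytesConfig and accepts quantization_config in from_pretrained, plus a CUDA GPU with bitsandbytes installed; load_base_in_8bit is a hypothetical helper name, and the card does not show the author's actual loading code. Because load_in_8bit is True, only the llm_int8_* settings take effect; the bnb_4bit_* entries match the library defaults and apply only to 4-bit loading.

```python
# Settings copied from the quantization log above.
BNB_KWARGS = {
    "load_in_8bit": True,
    "load_in_4bit": False,
    "llm_int8_threshold": 6.0,
    "llm_int8_skip_modules": None,
    "llm_int8_enable_fp32_cpu_offload": False,
    "llm_int8_has_fp16_weight": False,
    "bnb_4bit_quant_type": "fp4",        # ignored for 8-bit loading
    "bnb_4bit_use_double_quant": False,  # ignored for 8-bit loading
}


def load_base_in_8bit(base_model: str = "EleutherAI/gpt-j-6B"):
    """Hypothetical helper: load the base model with the logged 8-bit settings."""
    # Imported lazily so the settings above can be inspected without
    # torch/transformers installed; calling this needs a CUDA GPU + bitsandbytes.
    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    bnb_config = BitsAndBytesConfig(
        **BNB_KWARGS,
        bnb_4bit_compute_dtype=torch.float32,  # also 4-bit only
    )
    return AutoModelForCausalLM.from_pretrained(
        base_model, quantization_config=bnb_config, device_map="auto"
    )
```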

Limitations and Biases

The following language is adapted from the GPT-J-6B model card:

This model can produce factually incorrect output and should not be relied on to produce factually accurate information. This model was trained on various public datasets. While great efforts have been taken to clean the pretraining data, it is possible that this model could generate lewd, biased, or otherwise offensive outputs.

How to Use

Install and import the package dependencies (the leading "!" is Jupyter/Colab syntax; omit it when running pip from a regular shell):

!pip install -q -U huggingface_hub peft transformers torch accelerate bitsandbytes
import re

from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

Basic model loading:

INTRO = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
INSTRUCTION_FORMAT = (
    """{intro} ### Instruction: {instruction} ### Input: {input} ### Response: """
)

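def load_model_tokenizer_for_generate(pretrained_model_name_or_path: str):
    # Loads a plain causal LM and its tokenizer. To use the MedDoctor adapter,
    # wrap the returned model with PeftModel.from_pretrained (see the __main__ block).
    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path, padding_side="left"
    )
    # The GPT-J tokenizer defines no pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token
    # device_map="auto" (requires accelerate) places weights on available GPU(s)/CPU.
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path, device_map="auto"
    )
    return model, tokenizer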
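To see exactly what the tokenizer receives, the snippet below renders the prompt template by itself, without loading a model. The template strings are copied from the code above and the instruction from the __main__ block further down; build_prompt is a hypothetical helper added only for this illustration. Note that the whole prompt is a single line, with the ### markers separating its sections.

```python
INTRO = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
INSTRUCTION_FORMAT = (
    """{intro} ### Instruction: {instruction} ### Input: {input} ### Response: """
)


def build_prompt(instruction: str, input_text: str) -> str:
    """Fill the template the same way generate_response does before tokenizing."""
    return INSTRUCTION_FORMAT.format(
        intro=INTRO, instruction=instruction, input=input_text
    )


prompt = build_prompt(
    "If you are a doctor, please answer the medical questions based on user's query",
    "what are the preventive measures of diabetes?",
)
print(prompt)
```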

Once loaded, the model and tokenizer can be used with the following code:

def generate_response(
    instruction: str,
    input_text: str,
    *,
    model,
    tokenizer,
    do_sample: bool = True,
    max_new_tokens: int = 500,
    top_p: float = 0.92,
    top_k: int = 0,
    **kwargs,
) -> str:
    input_ids = tokenizer(
        INSTRUCTION_FORMAT.format(
            intro=INTRO, instruction=instruction, input=input_text
        ),
        return_tensors="pt",
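    ).input_ids.to(model.device)  # keep inputs on the same device as the model
    gen_tokens = model.generate(
        input_ids=input_ids,
        # GPT-J's tokenizer defines no pad token, so pad with EOS instead of None.
        pad_token_id=tokenizer.eos_token_id,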
        do_sample=do_sample,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        **kwargs,
    )
    decoded = tokenizer.batch_decode(gen_tokens)[0]

    # The response appears after "### Response:".  The model has been trained to append "### End" at the end.
    m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", decoded, flags=re.DOTALL)

    response = None
    if m:
        response = m.group(1).strip()
    else:
        # The model might not generate the "### End" sequence before reaching the max tokens.  In this case, return
        # everything after "### Response:".
        m = re.search(r"#+\s*Response:\s*(.+)", decoded, flags=re.DOTALL)
        if m:
            response = m.group(1).strip()
        else:
            print(f"Failed to find response in:\n{decoded}")

    return response
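The response-extraction step is plain string handling, so it can be checked without a GPU. The sketch below pulls the two regular expressions from generate_response into a standalone function (extract_response is a hypothetical name) and runs it on hand-written strings that imitate decoded output; they are not real model generations.

```python
import re
from typing import Optional


def extract_response(decoded: str) -> Optional[str]:
    """Return the text after "### Response:", trimmed at "### End" if present."""
    # Normal case: the model emitted the "### End" marker after its answer.
    m = re.search(r"#+\s*Response:\s*(.+?)#+\s*End", decoded, flags=re.DOTALL)
    if m is None:
        # Generation stopped (e.g. at max_new_tokens) before "### End".
        m = re.search(r"#+\s*Response:\s*(.+)", decoded, flags=re.DOTALL)
    return m.group(1).strip() if m else None


prompt = (
    "Below is an instruction. ### Instruction: Answer as a doctor. "
    "### Input: what are the preventive measures of diabetes? "
)
finished = prompt + "### Response: Get physical activity. ### End"
truncated = prompt + "### Response: Get physical activity and"

print(extract_response(finished))   # Get physical activity.
print(extract_response(truncated))  # Get physical activity and
print(extract_response(prompt))     # None
```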

We can now generate text! For example:

if __name__ == "__main__":
    base_model = "EleutherAI/gpt-j-6B"
    peft_model_id = "ghimiresunil/MedDoctor"
    config = PeftConfig.from_pretrained(peft_model_id)
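    # device_map="auto" (requires accelerate) spreads the ~24 GB of fp32 weights over GPU/CPU memory.
    model = AutoModelForCausalLM.from_pretrained(base_model, return_dict=True, device_map="auto")
    trained_model = PeftModel.from_pretrained(model, peft_model_id)
    trained_model.eval()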

    tokenizer = AutoTokenizer.from_pretrained(base_model)

    print("Welcome to the response generation program!")
    while True:
        instruction = "If you are a doctor, please answer the medical questions based on user's query"
        input_text = input("Enter the input text: ")
        response = generate_response(
            instruction=instruction,
            input_text=input_text,
            model=trained_model,
            tokenizer=tokenizer,
        )
        print('*' * 100)
        print("Generated Response:")
        print(response)
        print('*' * 100)

        continue_generation = input("Do you want to continue (yes/no)? ").lower()
        if continue_generation != "yes":
            print("Exiting the response generation program.")
            break

Acknowledgements

This model was finetuned by Sunil Ghimire on July 23, 2023 and is intended primarily for research purposes.

Disclaimer

The license on this model does not constitute legal advice. We are not responsible for the actions of third parties who use this model. Please consult an attorney before using this model for commercial purposes.

Citation and Related Information for GPT-J-6B

To cite this model:

@misc{gpt-j,
  author = {Wang, Ben and Komatsuzaki, Aran},
  title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
  howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
  year = 2021,
  month = May
}

To cite the codebase that trained this model:

@misc{mesh-transformer-jax,
  author = {Wang, Ben},
  title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
  howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
  year = 2021,
  month = May
}

Framework versions

  • PEFT 0.4.0