metadata
library_name: peft
license: apache-2.0
language:
  - en
tags:
  - medical
  - llama2

Inference

import torch
from peft import PeftModel, PeftConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

peft_model_id = "qanastek/MedAlpaca-LLaMa2-7B"
config = PeftConfig.from_pretrained(peft_model_id)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config,
    use_auth_token=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token

model = PeftModel.from_pretrained(model, peft_model_id)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def generate(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 1.0,
) -> str:

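    inputs = tokenizer([prompt], return_tensors="pt", return_token_type_ids=False).to(device)

    response = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # without sampling, `temperature` is ignored and decoding is greedy
        temperature=temperature,
        return_dict_in_generate=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) breaks whenever detokenization does not reproduce the prompt exactly.
    prompt_length = inputs["input_ids"].shape[1]
    return tokenizer.decode(response.sequences[0][prompt_length:], skip_special_tokens=True)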

prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. ### Instruction: We are giving you a scientific question (easy level) and five answers options (associated to « A », « B », « C », « D », « E »). Your task is to find the correct(s) answer(s) based on scientific facts, knowledge and reasoning. Don't generate anything other than one of the following characters : 'A B C D E'. ### Input: Among the following propositions, only one is correct; which? The most active thyroid hormone at the cellular level is: (A) Triiodothyronine (T3) (B) Tetraiodothyronine (T4) (C) 3,3',5'-triiodothyronine (rT3) (D) Thyroglobulin (E) Triiodothyroacetic acid ### Response:\n"

response = generate(
    model,
    tokenizer,
    prompt,
    max_new_tokens=500,
    temperature=0.92,
)

print(response)
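The example prompt follows an Alpaca-style layout (preamble, `### Instruction:`, `### Input:`, `### Response:`), and its instruction asks the model to answer with a single letter. The sketch below, assuming the same layout is reused for other multiple-choice questions, shows one way to assemble such prompts and to pull the chosen letter back out of the generated text. `build_prompt`, `extract_choice`, and the constant names are hypothetical helpers, not part of this repository or of PEFT.

```python
import re
from typing import Optional, Sequence

# Preamble and instruction copied verbatim from the example prompt above.
PREAMBLE = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes the request."
)
INSTRUCTION = (
    "We are giving you a scientific question (easy level) and five answers options "
    "(associated to « A », « B », « C », « D », « E »). Your task is to find the "
    "correct(s) answer(s) based on scientific facts, knowledge and reasoning. "
    "Don't generate anything other than one of the following characters : 'A B C D E'."
)
LETTERS = "ABCDE"


def build_prompt(question: str, options: Sequence[str]) -> str:
    """Hypothetical helper: lay out a question and its options like the example prompt."""
    if not 1 <= len(options) <= len(LETTERS):
        raise ValueError(f"expected 1 to {len(LETTERS)} options, got {len(options)}")
    # Label options (A), (B), ... in order and join them on one line.
    choices = " ".join(f"({letter}) {text}" for letter, text in zip(LETTERS, options))
    return (
        f"{PREAMBLE} ### Instruction: {INSTRUCTION} "
        f"### Input: {question} {choices} ### Response:\n"
    )


def extract_choice(response: str) -> Optional[str]:
    """Hypothetical helper: return the first standalone capital letter A-E, or None."""
    match = re.search(r"\b([A-E])\b", response)
    return match.group(1) if match else None
```

With the model loaded as above, `extract_choice(generate(model, tokenizer, build_prompt(question, options)))` returns the predicted letter, or `None` if the output contains no option letter. The regular expression is a heuristic: a capital "A" used as an English article would also match, so it suits the short, letter-only outputs the instruction asks for.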

Training procedure

Base model: meta-llama/Llama-2-7b-hf

The following bitsandbytes quantization config was used during training:

  • load_in_8bit: False
  • load_in_4bit: True
  • llm_int8_threshold: 6.0
  • llm_int8_skip_modules: None
  • llm_int8_enable_fp32_cpu_offload: False
  • llm_int8_has_fp16_weight: False
  • bnb_4bit_quant_type: nf4
  • bnb_4bit_use_double_quant: False
  • bnb_4bit_compute_dtype: float16
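The Inference snippet sets `bnb_4bit_compute_dtype` to bfloat16, while the list above records float16 for training. A minimal sketch for spotting such mismatches writes both configurations as plain dictionaries and compares the keys the Inference snippet passes explicitly. The names `TRAINING_BNB`, `INFERENCE_BNB`, and `config_differences` are hypothetical, chosen for illustration only.

```python
# Training-time 4-bit settings, copied from the list above.
TRAINING_BNB = {
    "load_in_8bit": False,
    "load_in_4bit": True,
    "llm_int8_threshold": 6.0,
    "llm_int8_skip_modules": None,
    "llm_int8_enable_fp32_cpu_offload": False,
    "llm_int8_has_fp16_weight": False,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": False,
    "bnb_4bit_compute_dtype": "float16",
}

# Settings passed explicitly in the Inference snippet; all other keys are left
# at their BitsAndBytesConfig defaults there.
INFERENCE_BNB = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": "bfloat16",
}


def config_differences(reference: dict, other: dict) -> dict:
    """Map each key set in both dicts with differing values to (reference, other)."""
    return {
        key: (reference[key], value)
        for key, value in other.items()
        if key in reference and reference[key] != value
    }


print(config_differences(TRAINING_BNB, INFERENCE_BNB))
# {'bnb_4bit_compute_dtype': ('float16', 'bfloat16')}
```

Because `BitsAndBytesConfig` in recent `transformers` releases accepts the compute dtype as a string, `BitsAndBytesConfig(**TRAINING_BNB)` should recreate the training-time settings, assuming `bitsandbytes` is installed.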

Framework versions

  • PEFT 0.4.0