---
library_name: peft
license: apache-2.0
language:
- en
tags:
- medical
- llama2
---
## Inference
```python
import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

peft_model_id = "qanastek/MedAlpaca-LLaMa2-7B"
# The adapter config records which base model the LoRA weights were trained on
config = PeftConfig.from_pretrained(peft_model_id)

# Load the base model with 4-bit NF4 quantization, as used during training
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=bnb_config,
    use_auth_token=True,  # Llama 2 is gated; newer transformers versions use `token=True`
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token  # the Llama 2 tokenizer has no pad token

# Attach the MedAlpaca LoRA adapter to the quantized base model
model = PeftModel.from_pretrained(model, peft_model_id)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate(
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    prompt: str,
    max_new_tokens: int = 128,
    temperature: float = 1.0,
) -> str:
    inputs = tokenizer([prompt], return_tensors="pt", return_token_type_ids=False).to(device)
    # with torch.autocast("cuda", dtype=torch.bfloat16):
    response = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,  # temperature only takes effect when sampling
        temperature=temperature,
        return_dict_in_generate=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    # Decode only the newly generated tokens instead of slicing the decoded
    # text by len(prompt), which breaks if decoding does not echo the prompt exactly
    new_tokens = response.sequences[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. ### Instruction: We are giving you a scientific question (easy level) and five answers options (associated to « A », « B », « C », « D », « E »). Your task is to find the correct(s) answer(s) based on scientific facts, knowledge and reasoning. Don't generate anything other than one of the following characters : 'A B C D E'. ### Input: Among the following propositions, only one is correct; which? The most active thyroid hormone at the cellular level is: (A) Triiodothyronine (T3) (B) Tetraiodothyronine (T4) (C) 3,3',5'-triiodothyronine (rT3) (D) Thyroglobulin (E) Triiodothyroacetic acid ### Response:\n"

response = generate(
    model,
    tokenizer,
    prompt,
    max_new_tokens=500,
    temperature=0.92,
)
print(response)
```
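The example prompt uses an Alpaca-style layout: a fixed preamble, then `### Instruction:`, `### Input:` (question plus options A to E) and `### Response:`. A minimal sketch, assuming you want to reuse that layout for other five-option questions: `build_prompt` and `extract_answer` are hypothetical helpers, not part of this repository or of PEFT/transformers, and the regex in `extract_answer` is a simple heuristic that returns the first standalone capital letter A to E.

```python
from __future__ import annotations

import re

# Hypothetical helpers (not part of this repository): they mirror the
# single-line Alpaca-style layout of the example prompt.
PREAMBLE = (
    "Below is an instruction that describes a task, paired with an input that "
    "provides further context. Write a response that appropriately completes "
    "the request."
)
INSTRUCTION = (
    "We are giving you a scientific question (easy level) and five answers "
    "options (associated to « A », « B », « C », « D », « E »). Your task is to "
    "find the correct(s) answer(s) based on scientific facts, knowledge and "
    "reasoning. Don't generate anything other than one of the following "
    "characters : 'A B C D E'."
)
LETTERS = "ABCDE"


def build_prompt(question: str, options: list[str]) -> str:
    """Assemble preamble, instruction, question with labelled options, response cue."""
    if len(options) != len(LETTERS):
        raise ValueError("expected exactly five answer options")
    choices = " ".join(f"({letter}) {text}" for letter, text in zip(LETTERS, options))
    return (
        f"{PREAMBLE} ### Instruction: {INSTRUCTION} "
        f"### Input: {question} {choices} ### Response:\n"
    )


def extract_answer(response: str) -> str | None:
    """Heuristic: return the first standalone capital letter A-E, or None."""
    match = re.search(r"\b([A-E])\b", response)
    return match.group(1) if match else None
```

The output of `build_prompt` can be passed as `prompt` to `generate`, and `extract_answer` applied to the returned text; a `None` result means the model did not answer in the requested format, and only the first letter is kept if it lists several.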
## Training procedure

Base model: meta-llama/Llama-2-7b-hf

The following `bitsandbytes` quantization config was used during training:
- load_in_8bit: False
- load_in_4bit: True
- llm_int8_threshold: 6.0
- llm_int8_skip_modules: None
- llm_int8_enable_fp32_cpu_offload: False
- llm_int8_has_fp16_weight: False
- bnb_4bit_quant_type: nf4
- bnb_4bit_use_double_quant: False
- bnb_4bit_compute_dtype: float16
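These settings store the base weights in 4-bit NF4 without double quantization. As a rough back-of-the-envelope sketch of what that means for weight memory, the snippet below assumes about 6.74 billion parameters for Llama-2-7B, the bitsandbytes default 4-bit block size of 64, and QLoRA-style bit accounting: one fp32 absmax constant per block, or, with double quantization, 8-bit block constants plus one fp32 constant per 256 blocks. None of these figures are stated in this card, and the estimate ignores layers that usually stay in 16-bit, such as the embeddings and `lm_head`.

```python
# Rough, hypothetical estimate of weight storage for the base model.
# Assumptions (not stated in this card): ~6.74e9 parameters, block size 64,
# and every parameter quantized. Real usage is higher because some layers
# (e.g. embeddings, lm_head) typically stay in 16-bit.
N_PARAMS = 6.74e9
BLOCKSIZE = 64


def bits_per_param(double_quant: bool) -> float:
    """4 data bits plus the per-block quantization constant, amortized per weight."""
    if double_quant:
        # 8-bit block constants, plus one fp32 constant per 256 blocks
        return 4 + 8 / BLOCKSIZE + 32 / (BLOCKSIZE * 256)
    # One fp32 absmax constant per block of 64 weights
    return 4 + 32 / BLOCKSIZE


def weight_gb(bits: float, n_params: float = N_PARAMS) -> float:
    """Convert a per-parameter bit width into total gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits / 8 / 1e9


fp16_gb = weight_gb(16)
nf4_gb = weight_gb(bits_per_param(double_quant=False))
nf4_dq_gb = weight_gb(bits_per_param(double_quant=True))
print(f"fp16: {fp16_gb:.2f} GB, nf4: {nf4_gb:.2f} GB, nf4 + double quant: {nf4_dq_gb:.2f} GB")
```

Under these assumptions, enabling `bnb_4bit_use_double_quant` would save roughly 0.37 bits per parameter, about 0.3 GB for a model of this size.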
### Framework versions
- PEFT 0.4.0