File size: 811 Bytes
f732d7c
d6c0cbe
f732d7c
b555022
 
 
 
0e9fa83
 
f732d7c
 
 
 
b555022
 
f732d7c
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import transformers
from peft import PeftModel

# Fail fast with an actionable message when the installed transformers
# package does not provide the LLaMA classes.
# Attempting the real import is used instead of `assert`: assertions are
# stripped under `python -O`, which would silently remove this check.
# It also avoids inspecting the private `transformers._import_structure`
# mapping, whose layout is not a public API and may differ between releases.
try:
    from transformers import LlamaForCausalLM, LlamaTokenizer
except ImportError as err:
    raise ImportError(
        "LLaMA is now in HuggingFace's main branch.\nPlease reinstall it: "
        "pip uninstall transformers && pip install "
        "git+https://github.com/huggingface/transformers.git"
    ) from err

# Hub repository ids: the base LLaMA-13B checkpoint, and the LoRA adapter
# weights that are loaded on top of it.
BASE_MODEL = "decapoda-research/llama-13b-hf"
LORA_WEIGHTS = "izumi-lab/llama-13b-japanese-lora-v0-1ep"

# Tokenizer that belongs to the base checkpoint.
tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)

# Base causal-LM weights in float16 with 8-bit loading explicitly disabled.
# device_map="auto" leaves device placement to transformers/accelerate.
model = LlamaForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    torch_dtype=torch.float16,
    load_in_8bit=False,
)

# Wrap the base model with the LoRA adapter; `model` is rebound to the
# resulting PeftModel. use_auth_token=True presumably makes the Hub download
# use locally stored credentials (verify against the installed peft version).
model = PeftModel.from_pretrained(
    model,
    LORA_WEIGHTS,
    use_auth_token=True,
    torch_dtype=torch.float16,
)