alexghergh
committed on
Commit
•
1146dae
1
Parent(s):
2c2e591
Update inference.py
Browse files- inference.py +4 -6
inference.py
CHANGED
@@ -6,17 +6,15 @@ from peft import PeftModel, PeftConfig
|
|
6 |
import torch
|
7 |
|
8 |
orig_checkpoint = 'google/gemma-2b'
|
9 |
-
checkpoint = '
|
10 |
HF_TOKEN = ''
|
11 |
PROMPT = 'Salut, ca sa imi schimb buletinul pot sa'
|
12 |
|
13 |
-
seq_len =
|
14 |
|
15 |
# load original model first
|
16 |
tokenizer = AutoTokenizer.from_pretrained(orig_checkpoint, token=HF_TOKEN)
|
17 |
-
|
18 |
-
config = PeftConfig.from_pretrained(checkpoint)
|
19 |
-
model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, token=HF_TOKEN)
|
20 |
|
21 |
# then merge trained QLoRA weights
|
22 |
model = PeftModel.from_pretrained(model, checkpoint)
|
@@ -28,4 +26,4 @@ model = model.cuda()
|
|
28 |
inputs = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
|
29 |
outputs = model.generate(inputs, max_new_tokens=seq_len)
|
30 |
|
31 |
-
print(tokenizer.decode(outputs[0]))
|
|
|
6 |
import torch
|
7 |
|
8 |
orig_checkpoint = 'google/gemma-2b'
|
9 |
+
checkpoint = '.'
|
10 |
HF_TOKEN = ''
|
11 |
PROMPT = 'Salut, ca sa imi schimb buletinul pot sa'
|
12 |
|
13 |
+
seq_len = 256
|
14 |
|
15 |
# load original model first
|
16 |
tokenizer = AutoTokenizer.from_pretrained(orig_checkpoint, token=HF_TOKEN)
|
17 |
+
model = AutoModelForCausalLM.from_pretrained(orig_checkpoint, token=HF_TOKEN)
|
|
|
|
|
18 |
|
19 |
# then merge trained QLoRA weights
|
20 |
model = PeftModel.from_pretrained(model, checkpoint)
|
|
|
26 |
inputs = tokenizer.encode(PROMPT, return_tensors="pt").cuda()
|
27 |
outputs = model.generate(inputs, max_new_tokens=seq_len)
|
28 |
|
29 |
+
print(tokenizer.decode(outputs[0]))
|