Update utils/inference.py
Browse files
- utils/inference.py +6 -7
utils/inference.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import torch
|
2 |
-
from transformers import
|
3 |
from peft import PeftModel
|
4 |
from typing import Iterator
|
5 |
from variables import SYSTEM, HUMAN, AI
|
@@ -8,7 +8,6 @@ from variables import SYSTEM, HUMAN, AI
|
|
8 |
def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
|
9 |
"""
|
10 |
Loads the tokenizer and chatbot model.
|
11 |
-
|
12 |
Args:
|
13 |
base_model (str): The base model to use (path to the model).
|
14 |
adapter_model (str): The LoRA model to use (path to LoRA model).
|
@@ -24,15 +23,15 @@ def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
|
|
24 |
device = "mps"
|
25 |
except:
|
26 |
pass
|
27 |
-
tokenizer =
|
28 |
if device == "cuda":
|
29 |
-
model =
|
30 |
base_model,
|
31 |
load_in_8bit=load_8bit,
|
32 |
torch_dtype=torch.float16
|
33 |
)
|
34 |
elif device == "mps":
|
35 |
-
model =
|
36 |
base_model,
|
37 |
device_map={"": device}
|
38 |
)
|
@@ -44,7 +43,7 @@ def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
|
|
44 |
torch_dtype=torch.float16,
|
45 |
)
|
46 |
else:
|
47 |
-
model =
|
48 |
base_model,
|
49 |
device_map={"": device},
|
50 |
low_cpu_mem_usage=True,
|
@@ -76,7 +75,7 @@ shared_state = State()
|
|
76 |
def decode(
|
77 |
input_ids: torch.Tensor,
|
78 |
model: PeftModel,
|
79 |
-
tokenizer:
|
80 |
stop_words: list,
|
81 |
max_length: int,
|
82 |
temperature: float = 1.0,
|
|
|
1 |
import torch
|
2 |
+
from transformers import LlamaTokenizer, LlamaForCausalLM
|
3 |
from peft import PeftModel
|
4 |
from typing import Iterator
|
5 |
from variables import SYSTEM, HUMAN, AI
|
|
|
8 |
def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
|
9 |
"""
|
10 |
Loads the tokenizer and chatbot model.
|
|
|
11 |
Args:
|
12 |
base_model (str): The base model to use (path to the model).
|
13 |
adapter_model (str): The LoRA model to use (path to LoRA model).
|
|
|
23 |
device = "mps"
|
24 |
except:
|
25 |
pass
|
26 |
+
tokenizer = LlamaTokenizer.from_pretrained(base_model)
|
27 |
if device == "cuda":
|
28 |
+
model = LlamaForCausalLM.from_pretrained(
|
29 |
base_model,
|
30 |
load_in_8bit=load_8bit,
|
31 |
torch_dtype=torch.float16
|
32 |
)
|
33 |
elif device == "mps":
|
34 |
+
model = LlamaForCausalLM.from_pretrained(
|
35 |
base_model,
|
36 |
device_map={"": device}
|
37 |
)
|
|
|
43 |
torch_dtype=torch.float16,
|
44 |
)
|
45 |
else:
|
46 |
+
model = LlamaForCausalLM.from_pretrained(
|
47 |
base_model,
|
48 |
device_map={"": device},
|
49 |
low_cpu_mem_usage=True,
|
|
|
75 |
def decode(
|
76 |
input_ids: torch.Tensor,
|
77 |
model: PeftModel,
|
78 |
+
tokenizer: LlamaTokenizer,
|
79 |
stop_words: list,
|
80 |
max_length: int,
|
81 |
temperature: float = 1.0,
|