tianyang committed
Commit ed9d322
Parent: b143c1f

Update utils/inference.py

Files changed (1): utils/inference.py (+6 -7)
utils/inference.py CHANGED
@@ -1,5 +1,5 @@
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import LlamaTokenizer, LlamaForCausalLM
 from peft import PeftModel
 from typing import Iterator
 from variables import SYSTEM, HUMAN, AI
@@ -8,7 +8,6 @@ from variables import SYSTEM, HUMAN, AI
 def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
     """
     Loads the tokenizer and chatbot model.
-
     Args:
         base_model (str): The base model to use (path to the model).
         adapter_model (str): The LoRA model to use (path to LoRA model).
@@ -24,15 +23,15 @@ def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
         device = "mps"
     except:
         pass
-    tokenizer = AutoTokenizer.from_pretrained(base_model)
+    tokenizer = LlamaTokenizer.from_pretrained(base_model)
     if device == "cuda":
-        model = AutoModelForCausalLM.from_pretrained(
+        model = LlamaForCausalLM.from_pretrained(
             base_model,
             load_in_8bit=load_8bit,
             torch_dtype=torch.float16
         )
     elif device == "mps":
-        model = AutoModelForCausalLM.from_pretrained(
+        model = LlamaForCausalLM.from_pretrained(
             base_model,
             device_map={"": device}
         )
@@ -44,7 +43,7 @@ def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
             torch_dtype=torch.float16,
         )
     else:
-        model = AutoModelForCausalLM.from_pretrained(
+        model = LlamaForCausalLM.from_pretrained(
             base_model,
             device_map={"": device},
             low_cpu_mem_usage=True,
@@ -76,7 +75,7 @@ shared_state = State()
 def decode(
     input_ids: torch.Tensor,
     model: PeftModel,
-    tokenizer: AutoTokenizer,
+    tokenizer: LlamaTokenizer,
     stop_words: list,
     max_length: int,
     temperature: float = 1.0,
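
For context, a minimal standalone sketch of the loading pattern this commit switches to: the Llama-specific LlamaTokenizer and LlamaForCausalLM classes in place of the Auto* classes, with the LoRA adapter applied via PEFT. The paths are placeholders, and the PeftModel call is inferred from the hunk context (the torch_dtype line under the mps branch) rather than shown in full in the diff, so treat this as an assumption-laden illustration of the CUDA branch, not the repository's exact code.

import torch
from transformers import LlamaTokenizer, LlamaForCausalLM
from peft import PeftModel

base_model = "path/to/llama-base"        # placeholder: any Llama checkpoint
adapter_model = "path/to/lora-adapter"   # placeholder: a PEFT LoRA adapter

# Llama-specific classes, as in the commit; avoids relying on Auto* class
# resolution for checkpoints whose configs do not map cleanly.
tokenizer = LlamaTokenizer.from_pretrained(base_model)
model = LlamaForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=True,            # requires bitsandbytes and a CUDA device
    torch_dtype=torch.float16,
)
model = PeftModel.from_pretrained(model, adapter_model, torch_dtype=torch.float16)
model.eval()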