dh-mc commited on
Commit
2847edc
·
1 Parent(s): 6b4da82

Update llm_utils.py

Browse files
Files changed (1) hide show
  1. llm_toolkit/llm_utils.py +12 -1
llm_toolkit/llm_utils.py CHANGED
@@ -10,6 +10,17 @@ from transformers import (
10
  from tqdm import tqdm
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
13
  def load_model(
14
  model_name,
15
  dtype=torch.bfloat16,
@@ -22,7 +33,7 @@ def load_model(
22
  if adapter_name_or_path and using_llama_factory:
23
  from llamafactory.chat import ChatModel
24
 
25
- template = "llama3" if "llama-3" in model_name.lower() else "chatml"
26
 
27
  args = dict(
28
  model_name_or_path=model_name,
 
10
  from tqdm import tqdm
11
 
12
 
13
+ def get_template(model_name):
14
+ model_name = model_name.lower()
15
+ if "llama-3" in model_name:
16
+ return "llama3"
17
+ if "internlm" in model_name:
18
+ return "intern2"
19
+ if "glm" in model_name:
20
+ return "glm4"
21
+ return "chatml"
22
+
23
+
24
  def load_model(
25
  model_name,
26
  dtype=torch.bfloat16,
 
33
  if adapter_name_or_path and using_llama_factory:
34
  from llamafactory.chat import ChatModel
35
 
36
+ template = get_template(model_name)
37
 
38
  args = dict(
39
  model_name_or_path=model_name,