hiwei commited on
Commit
780e3dc
1 Parent(s): 36db0fa

fix: wrong ChatGLM2 inti

Browse files
Files changed (1) hide show
  1. chatglm2_6b/modelClient.py +3 -3
chatglm2_6b/modelClient.py CHANGED
@@ -18,8 +18,9 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
18
 
19
  class ChatGLM2(object):
20
  def __init__(self, model_path=None):
21
- if not model_path:
22
- self.model_path = DEFAULT_MODEL_PATH
 
23
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
24
  model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda()
25
  self.model = model.eval()
@@ -79,4 +80,3 @@ class ChatGLM2(object):
79
  max_length=max_length, do_sample=do_sample, top_p=top_p, temperature=temperature)
80
  for resp, new_history in stream:
81
  yield resp, new_history
82
-
 
18
 
19
  class ChatGLM2(object):
20
  def __init__(self, model_path=None):
21
+ self.model_path = DEFAULT_MODEL_PATH
22
+ if model_path:
23
+ self.model_path = model_path
24
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True)
25
  model = AutoModel.from_pretrained(self.model_path, trust_remote_code=True).half().cuda()
26
  self.model = model.eval()
 
80
  max_length=max_length, do_sample=do_sample, top_p=top_p, temperature=temperature)
81
  for resp, new_history in stream:
82
  yield resp, new_history