arkii committed on
Commit 3da7b5e
1 Parent(s): 481948b

Create chatglm_langchain.py

Files changed (1)
  1. chatglm_langchain.py +85 -0
chatglm_langchain.py ADDED
@@ -0,0 +1,85 @@
+ from typing import Any, List, Mapping, Optional
+
+ import chatglm_cpp
+ from langchain import LLMChain, PromptTemplate
+ from langchain.callbacks.manager import CallbackManager, CallbackManagerForLLMRun
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+ from langchain.llms.base import LLM
+
+ DEFAULT_MODEL_PATH = "chatglm2-6b-ggml.q8_0.bin"
+
+ callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
+ pipeline = chatglm_cpp.Pipeline(DEFAULT_MODEL_PATH)
+
+
+ class ChatGLM(LLM):
+     """LangChain LLM wrapper around a chatglm.cpp pipeline."""
+
+     temperature: float = 0.7
+     base_model: str = DEFAULT_MODEL_PATH
+     max_length: int = 2048
+     verbose: bool = False
+     streaming: bool = False
+     top_p: float = 0.9
+     top_k: int = 0
+     max_context_length: int = 512
+     threads: int = 0
+
+     @property
+     def _llm_type(self) -> str:
+         return "chatglm"
+
+     def _call(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+     ) -> str:
+         if stop is not None:
+             raise ValueError("stop kwargs are not permitted.")
+         history = [prompt]
+         response = ""
+         if self.streaming:
+             # Accumulate the streamed pieces into the full response.
+             for piece in pipeline.stream_chat(
+                 history,
+                 max_length=self.max_length,
+                 max_context_length=self.max_context_length,
+                 do_sample=self.temperature > 0,
+                 top_k=self.top_k,
+                 top_p=self.top_p,
+                 temperature=self.temperature,
+                 num_threads=self.threads,
+             ):
+                 response += piece
+             return response
+         else:
+             response = pipeline.chat(
+                 history,
+                 max_length=self.max_length,
+                 max_context_length=self.max_context_length,
+                 do_sample=self.temperature > 0,
+                 top_k=self.top_k,
+                 top_p=self.top_p,
+                 temperature=self.temperature,
+                 num_threads=self.threads,
+             )
+             return response
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         """Get the identifying parameters."""
+         return {
+             "temperature": self.temperature,
+             "base_model": self.base_model,
+             "max_length": self.max_length,
+             "verbose": self.verbose,
+             "streaming": self.streaming,
+             "top_p": self.top_p,
+             "top_k": self.top_k,
+             "max_context_length": self.max_context_length,
+             "threads": self.threads,
+         }
+
+
+ # Prompt template (Chinese): "Xiao Ming's mother has two children; one is called Da Ming. {question}"
+ template = "小明的妈妈有两个孩子,一个叫大明 {question}"
+ prompt = PromptTemplate(template=template, input_variables=["question"])
+ # Question (Chinese): "What is the other one called?"
+ question = "另外一个叫什么?"
+ llm = ChatGLM(streaming=False, callback_manager=callback_manager)
+ llm_chain = LLMChain(prompt=prompt, llm=llm)
+ print(llm_chain.run(question))