"""Minimal LangChain LLM wrapper around a local chatglm_cpp pipeline.

Loads a quantized ChatGLM2-6B GGML model, exposes it through the standard
LangChain ``LLM`` interface, and runs a small LLMChain demo.
"""
from typing import Any, List, Mapping, Optional

import chatglm_cpp

from langchain import LLMChain, PromptTemplate
from langchain.callbacks.manager import CallbackManager, CallbackManagerForLLMRun
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.base import LLM

# Path to a locally available, q8_0-quantized ChatGLM2-6B GGML model file.
DEFAULT_MODEL_PATH = "chatglm2-6b-ggml.q8_0.bin"

# The callback manager carries a stdout streaming handler for the LLM below;
# the chatglm_cpp pipeline is created once at import time, which loads the model.
callback_manager = CallbackManager([StreamingStdOutCallbackHandler()])
pipeline = chatglm_cpp.Pipeline(DEFAULT_MODEL_PATH)

class ChatGLM(LLM):
    """LangChain LLM wrapper around the module-level chatglm_cpp pipeline."""

    temperature: float = 0.7
    base_model: str = DEFAULT_MODEL_PATH
    max_length: int = 2048          # maximum total number of tokens (prompt + reply)
    verbose: bool = False
    streaming: bool = False         # if True, accumulate stream_chat() pieces instead of calling chat()
    top_p: float = 0.9
    top_k: int = 0
    max_context_length: int = 512   # maximum number of prompt tokens kept as context
    threads: int = 0                # 0 lets chatglm_cpp choose the thread count

    @property
    def _llm_type(self) -> str:
        return "chatglm"

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None, ) -> str:
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        print("Prompt: ", prompt)
        history = [prompt]
        response = ""
        if self.streaming:
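            # stream_chat() yields the reply in incremental pieces; collect them
            # into a single string so the LangChain interface still returns str.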
            for piece in pipeline.stream_chat(
                    history,
                    max_length=self.max_length,
                    max_context_length=self.max_context_length,
                    do_sample=self.temperature > 0,
                    top_k=self.top_k,
                    top_p=self.top_p,
                    temperature=self.temperature,
                    num_threads=self.threads,
            ):
                response += piece
            return response
            # NOTE: to expose tokens as they arrive, a generator-style variant
            # could yield each piece here and append the final response to the
            # history instead of returning the accumulated string.
        else:
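            # Non-streaming path: chat() returns the complete reply in one call.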
            response = pipeline.chat(
                history,
                max_length=self.max_length,
                max_context_length=self.max_context_length,
                do_sample=self.temperature > 0,
                top_k=self.top_k,
                top_p=self.top_p,
                temperature=self.temperature,
                num_threads=self.threads,
            )
            return response

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {"temperature": self.temperature,
                "base_model": self.base_model,
                "max_length": self.max_length,
                "verbose": self.verbose,
                "streaming": self.streaming,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "max_context_length": self.max_context_length,
                "threads": self.threads}


# Demo: a small Chinese riddle. The template reads "Xiaoming's mother has two
# children; one is called Daming. {question}" and the question is "What is the
# other one called?" (the expected answer is Xiaoming).
template = "小明的妈妈有两个孩子,一个叫大明 {question}"
prompt = PromptTemplate(template=template, input_variables=["question"])
question = "另外一个叫什么?"
llm = ChatGLM(streaming=False, callback_manager=callback_manager)
llm_chain = LLMChain(prompt=prompt, llm=llm)
print(llm_chain.run(question))
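
# Streaming note (illustrative sketch, not wired in above): the
# StreamingStdOutCallbackHandler only prints tokens that are reported to it via
# callbacks. For streaming=True to print incrementally, _call() could forward
# each piece inside its loop when a run_manager is provided, e.g.:
#   if run_manager:
#       run_manager.on_llm_new_token(piece)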