mwz committed on
Commit
2907c25
1 Parent(s): 088a354

Upload 5 files

Browse files
llama-2-7b-chat.ggmlv3.q2_K.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45833e0b59c8fe80676c664f556031fc411da8856e0716ac7b8ed201b7221c08
3
+ size 2866807424
llama_cpp_chat_completion_wrapper.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from llama_cpp import Llama
2
+ from functools import partial
3
+ from typing import List, Literal, TypedDict, Callable
4
+
5
# Closed set of speaker roles that may appear in a chat transcript.
Role = Literal["system", "user", "assistant"]


# One chat turn. Declared with the functional TypedDict form, which yields the
# same runtime type as the equivalent class-based declaration.
Message = TypedDict(
    "Message",
    {
        "role": Role,  # who produced this turn
        "content": str,  # text of the turn
    },
)
11
+
12
+
13
# Markers wrapped around every user instruction in the prompt.
B_INST = "[INST]"
E_INST = "[/INST]"

# Markers wrapped around the system prompt, which is embedded in the first
# user instruction.
B_SYS = "<<SYS>>\n"
E_SYS = "\n<</SYS>>\n\n"

# System prompt used when the conversation does not start with a system
# message: two paragraphs separated by a blank line.
DEFAULT_SYSTEM_PROMPT = (
    "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
    "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
    "Please ensure that your responses are socially unbiased and positive in nature."
    "\n\n"
    "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
    "If you don't know the answer to a question, please don't share false information."
)
19
+
20
+
21
def _llama2_format_messages(messages: List[Message], tokenizer_encode: Callable) -> List[int]:
    """Encode a chat transcript into a flat list of Llama-2 chat prompt tokens.

    If the transcript does not start with a system message,
    ``DEFAULT_SYSTEM_PROMPT`` is used. The system text is wrapped in
    ``B_SYS``/``E_SYS`` and folded into the message that follows it. Every
    completed (user, assistant) pair is encoded as one BOS...EOS segment, and
    the final user message is encoded with BOS only so the model continues
    from there. The caller's ``messages`` list is not modified.

    Args:
        messages: Transcript in order: optional system message, then
            user/assistant turns alternating, ending with a user turn.
        tokenizer_encode: Callable invoked as
            ``tokenizer_encode(text, bos=..., eos=...)`` that returns an
            iterable of token ids.

    Returns:
        The concatenated token ids for the whole prompt.

    Raises:
        AssertionError: If the transcript is empty, has no user message after
            the system message, does not alternate user/assistant, or does
            not end with a user message. The exception type matches the
            original ``assert``-based checks, but is raised explicitly so the
            validation still runs under ``python -O``.
    """
    if not messages:
        raise AssertionError("messages must contain at least one user message")

    if messages[0]["role"] != "system":
        messages = [{"role": "system", "content": DEFAULT_SYSTEM_PROMPT}] + messages
    if len(messages) < 2:
        raise AssertionError("the system message must be followed by at least one user message")

    # Fold the system prompt into the message that follows it; the merged
    # message keeps that follower's role.
    system_msg, first_msg = messages[0], messages[1]
    messages = [
        {
            "role": first_msg["role"],
            "content": B_SYS + system_msg["content"] + E_SYS + first_msg["content"],
        }
    ] + messages[2:]

    prompts, answers = messages[::2], messages[1::2]
    roles_alternate = all(msg["role"] == "user" for msg in prompts) and all(
        msg["role"] == "assistant" for msg in answers
    )
    if not roles_alternate:
        raise AssertionError(
            "model only supports 'system', 'user' and 'assistant' roles, "
            "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
        )
    if messages[-1]["role"] != "user":
        raise AssertionError(f"Last message must be from user, got {messages[-1]['role']}")

    # Extend one list in place rather than sum(list_of_lists, []), which
    # re-copies the accumulated tokens for every segment.
    messages_tokens: List[int] = []
    for prompt, answer in zip(prompts, answers):
        messages_tokens.extend(
            tokenizer_encode(
                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                bos=True,
                eos=True,
            )
        )

    # The pending user turn gets no EOS so generation continues after it.
    messages_tokens.extend(
        tokenizer_encode(
            f"{B_INST} {(messages[-1]['content']).strip()} {E_INST}",
            bos=True,
            eos=False,
        )
    )
    return messages_tokens
62
+
63
+
64
def _llama_cpp_tokenizer_encode(s: str, bos: bool, eos: bool, llm: Llama) -> List[int]:
    """Tokenize ``s`` with ``llm``, optionally adding BOS/EOS token ids.

    A single space is prepended to the UTF-8 encoded text before it is handed
    to ``llm.tokenize`` (presumably to mirror the reference tokenizer's
    leading-space behaviour; confirm against the model's tokenizer).

    Args:
        s: Text to tokenize; must be exactly of type ``str``.
        bos: When true, put ``llm.token_bos()`` in front of the tokens.
        eos: When true, put ``llm.token_eos()`` after the tokens.
        llm: Model object providing ``tokenize``, ``token_bos`` and
            ``token_eos``.

    Returns:
        The token ids, including any requested BOS/EOS markers.
    """
    assert type(s) is str
    payload = b" " + s.encode("utf-8")
    # BOS is added manually below, so the tokenizer must not add its own.
    tokens = llm.tokenize(text=payload, add_bos=False)
    if bos:
        tokens = [llm.token_bos(), *tokens]
    if eos:
        tokens = [*tokens, llm.token_eos()]
    return tokens
72
+
73
+
74
class Llama2ChatCompletionWrapper:
    """Stateful chat session on top of a llama.cpp ``Llama`` model.

    The transcript is kept in ``self.messages``. Each call appends the user
    message, formats the whole transcript with ``_llama2_format_messages``,
    generates a reply token by token, and appends the reply as an assistant
    message. If a ``callback`` was supplied, it is invoked with every message
    added to the transcript.
    """

    def __init__(
        self,
        model_path: str,
        callback: Callable[[Message], None] | None = None,
        tokenizer_encoder: Callable | None = None,
    ) -> None:
        """Load the model and set up an empty transcript.

        Args:
            model_path: Path passed to ``Llama(model_path=...)``.
            callback: Optional observer called with each message added to the
                transcript.
            tokenizer_encoder: Optional replacement for the default encoder;
                called as ``encoder(text, bos=..., eos=...)``. Defaults to
                ``_llama_cpp_tokenizer_encode`` bound to the loaded model.
        """
        self.llm = Llama(model_path=model_path)
        if tokenizer_encoder is None:
            self._tokenizer_encode = partial(_llama_cpp_tokenizer_encode, llm=self.llm)
        else:
            self._tokenizer_encode = tokenizer_encoder
        self.callback = callback
        # Start with an empty transcript so the wrapper can be called even if
        # new_session() was never invoked.
        self.messages: List[Message] = []

    def new_session(self, system_content: str | None = None, messages: List[Message] | None = None):
        """Reset the transcript, optionally seeding it.

        Args:
            system_content: If given, the new transcript starts with this
                system message. Mutually exclusive with ``messages``.
            messages: If given, used as the new transcript. The list object is
                stored as-is (not copied), so later turns are appended to it.

        Raises:
            AssertionError: If both ``system_content`` and ``messages`` are
                supplied. Raised explicitly so the check survives
                ``python -O``.
        """
        self.messages = []

        if system_content is not None:
            if messages is not None:
                raise AssertionError("pass either system_content or messages, not both")
            self.messages.append(Message(role="system", content=system_content))
        elif messages is not None:
            self.messages = messages

        # Report every message the new session starts with.
        if self.callback is not None:
            for msg in self.messages:
                self.callback(msg)

    def __call__(
        self,
        message: str,
        post_process: Callable[[str], str] | None = None,
        max_tokens: int = 128,
        params: dict | None = None,
    ) -> str:
        """Send a user message and return the assistant's reply.

        Args:
            message: The user's text; appended to the transcript.
            post_process: Optional function applied to the stripped reply
                before it is stored and returned.
            max_tokens: Upper bound on generated tokens. It is further capped
                by the room left in the model's context (``llm._n_ctx`` minus
                the prompt length). A negative value means "limited only by
                the remaining context".
            params: Extra keyword arguments forwarded to ``llm.generate``.

        Returns:
            The reply text (after ``post_process``, if any). It is empty when
            there is no room left in the context.
        """
        generate_kwargs = {} if params is None else params

        self.messages.append(Message(role="user", content=message))
        if self.callback is not None:
            self.callback(self.messages[-1])

        messages_tokens = _llama2_format_messages(self.messages, tokenizer_encode=self._tokenizer_encode)

        # Never ask for more tokens than the context can still hold. Clamp at
        # zero so an over-long prompt yields an empty reply rather than an
        # unbounded generation loop.
        room = max(0, self.llm._n_ctx - len(messages_tokens))
        limit = room if max_tokens < 0 else min(max_tokens, room)

        pieces: List[bytes] = []
        if limit > 0:
            eos_token = self.llm.token_eos()
            completion = self.llm.generate(messages_tokens, **generate_kwargs)
            for produced, token in enumerate(completion, start=1):
                if token == eos_token:
                    break
                pieces.append(self.llm.detokenize([token]))
                # Stop right after the last allowed token so no extra token is
                # generated and then thrown away.
                if produced >= limit:
                    break

        # Decode once over the joined bytes: a multi-byte UTF-8 character can
        # span several tokens, so per-token decoding may fail. A character cut
        # off by the token limit is dropped.
        result = b"".join(pieces).decode("utf-8", errors="ignore").strip()

        if post_process is not None:
            result = post_process(result)

        self.messages.append(Message(role="assistant", content=result))
        if self.callback is not None:
            self.callback(self.messages[-1])

        return result
requirements.txt ADDED
Binary file (1.47 kB). View file
 
streamlit_app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from llama_cpp_chat_completion_wrapper import Llama2ChatCompletionWrapper, Message
3
+ import streamlit as st
4
+
5
+
6
def console_print(message: Message) -> None:
    """Print a chat message to stdout, colour-coded by role with ANSI escapes.

    The role is printed upper-cased in a bold colour, followed by ``> `` and
    the content in the matching regular colour.

    Raises:
        KeyError: If the message's role has no entry in the colour table.
    """
    ansi_reset = "\033[00m"
    # role -> (style for the role label, style for the message text)
    palette = {
        "system": ("\033[1;35m", "\033[35m"),
        "user": ("\033[1;33m", "\033[33m"),
        "assistant": ("\033[1;31m", "\033[31m"),
        "assistant-before-post-process": ("\033[1;31m", "\033[31m"),
    }
    role = message["role"]
    label_style, body_style = palette[role]
    label = f"{label_style}{role.upper()}{ansi_reset}"
    body = f"{body_style}{message['content']}{ansi_reset}"
    print(f"{label}> {body}")
17
+
18
+
19
@st.cache_resource
def load_model():
    """Build the chat wrapper for the model file that sits next to this script.

    Wrapped in ``st.cache_resource``. Messages are echoed to the console via
    ``console_print``.
    """
    app_dir = os.path.dirname(__file__)
    weights_file = os.path.join(app_dir, "llama-2-7b-chat.ggmlv3.q2_K.bin")
    wrapper = Llama2ChatCompletionWrapper(model_path=weights_file, callback=console_print)
    return wrapper
23
+
24
+
25
def main():
    """Render the chatbot page and answer the prompt submitted in this run.

    On the first run of a session it stores the model wrapper, the generation
    parameters and the transcript in ``st.session_state``. Every run replays
    the transcript and, if the user submitted a prompt, shows it together
    with the model's reply.
    """
    state = st.session_state

    # Initialise on a missing key rather than on an empty session state, so
    # setup still happens when something else has already written to the
    # state. "messages" is written last, so its presence means setup finished.
    if "messages" not in state:
        state["llm"] = load_model()
        state["params"] = {
            "temp": 0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
        }
        state["llm"].new_session()
        # Same list object as the wrapper's transcript, so turns appended by
        # the wrapper show up here on the next rerun.
        # NOTE(review): load_model is wrapped in st.cache_resource, so the
        # wrapper (and its transcript) looks to be shared between sessions;
        # confirm this is intended.
        state["messages"] = state["llm"].messages

    st.title("💬 Chatbot")

    for msg in state["messages"]:
        st.chat_message(msg["role"]).write(msg["content"])

    if prompt := st.chat_input():
        st.chat_message("user").write(prompt)
        response = state["llm"](prompt, params=state["params"])
        st.chat_message("assistant").write(response)


if __name__ == "__main__":
    main()
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723