Yumenohoshi committed on
Commit
b9d2c44
1 Parent(s): c402338

Upload app2.py

Files changed (1)
app2.py +209 -0
app2.py ADDED
@@ -0,0 +1,209 @@
+ from typing import Any, List, Union, Mapping, Optional, Iterable
+
+ import ctranslate2
+ from ctranslate2 import GenerationStepResult, Generator
+ import transformers
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, AutoTokenizer
+
+ from langchain.llms.base import LLM
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
+
+ # Raw strings keep the Windows path backslashes from being read as escapes.
+ generator = Generator(r"E:\datafile\AI\FixedStar-model\FixedStar-BETA-7b-ct2")
+ tokenizer = AutoTokenizer.from_pretrained(r"E:\datafile\AI\FixedStar-model\FixedStar-BETA-7b")
+ TITLE = "Chat room!"
+
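+ # The *-ct2 directory is a CTranslate2 conversion of the Transformers checkpoint.
+ # If it still needs to be produced, a one-time conversion sketch (the int8
+ # quantization choice here is an assumption, not part of the original app):
+ #
+ #     ctranslate2.converters.TransformersConverter(
+ #         r"E:\datafile\AI\FixedStar-model\FixedStar-BETA-7b"
+ #     ).convert(r"E:\datafile\AI\FixedStar-model\FixedStar-BETA-7b-ct2", quantization="int8")
+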
+ class CTranslate2StreamLLM(LLM):
+
+     generator: Generator
+     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+
+     max_length: int = 128
+     repetition_penalty: float = 1.1
+     temperature: float = 0.6
+     topk: int = 1
+
+     @property
+     def _llm_type(self) -> str:
+         return "CTranslate2"
+
+     def _generate_tokens(
+         self,
+         prompt: str,
+     ) -> Iterable[GenerationStepResult]:
+
+         # Run inference: convert the prompt into tokenizer tokens first.
+         tokens = self.tokenizer.convert_ids_to_tokens(
+             self.tokenizer.encode(prompt, add_special_tokens=False)
+         )
+
+         step_results = self.generator.generate_tokens(
+             tokens,
+             max_length=self.max_length,
+             sampling_topk=self.topk,
+             sampling_temperature=self.temperature,
+             repetition_penalty=self.repetition_penalty,
+             return_log_prob=True,
+             # Model-specific token IDs that mark the end of a turn.
+             end_token=[26168, 27, 208, 14719, 9078, 18482, 27, 208],
+         )
+
+         return step_results
+
+     def _decode_with_buffer(
+         self, step_result: GenerationStepResult, token_buffer: list
+     ) -> Union[str, None]:
+         token_buffer.append(step_result.token_id)
+         word = self.tokenizer.decode(token_buffer)
+
+         # If nothing decodable was produced, keep buffering and return None.
+         if all(c == "�" for c in word):
+             return None
+
+         # If the step's token starts with ▁ (a SentencePiece word boundary),
+         # prepend a space.
+         if step_result.token.startswith("▁"):
+             word = " " + word
+
+         # A valid string was produced, so clear the buffer.
+         token_buffer.clear()
+
+         return word
+
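+     # Illustration of the buffering above: a multi-byte character split across
+     # two byte-fallback tokens decodes to "�" on the first step, so the buffer
+     # is kept; once the second token arrives the character decodes cleanly,
+     # the buffer is cleared, and the text is emitted.
+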
+     def _call(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+     ) -> str:
+         if stop is not None:
+             raise ValueError("stop kwargs are not permitted.")
+
+         step_results = self._generate_tokens(prompt)
+
+         output_ids = []
+         token_buffer = []
+
+         for step_result in step_results:
+             output_ids.append(step_result.token_id)
+
+             # Stream each decoded word to the callbacks as it is generated.
+             if run_manager:
+                 if word := self._decode_with_buffer(step_result, token_buffer):
+                     run_manager.on_llm_new_token(
+                         word,
+                         verbose=self.verbose,
+                         logprobs=step_result.log_prob if step_result.log_prob else None,
+                     )
+
+         if output_ids:
+             text = self.tokenizer.decode(output_ids)
+             return text
+
+         return ""
+
+     # Async variant of _call: same logic, but the callback is awaited.
+     async def _acall(
+         self,
+         prompt: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+     ) -> str:
+         if stop is not None:
+             raise ValueError("stop kwargs are not permitted.")
+
+         step_results = self._generate_tokens(prompt)
+
+         output_ids = []
+         token_buffer = []
+
+         for step_result in step_results:
+             output_ids.append(step_result.token_id)
+
+             if run_manager:
+                 if word := self._decode_with_buffer(step_result, token_buffer):
+                     await run_manager.on_llm_new_token(
+                         word,
+                         verbose=self.verbose,
+                         logprobs=step_result.log_prob if step_result.log_prob else None,
+                     )
+
+         if output_ids:
+             text = self.tokenizer.decode(output_ids)
+             return text
+
+         return ""
+
+     @property
+     def _identifying_params(self) -> Mapping[str, Any]:
+         """Get the identifying parameters."""
+         return {
+             "generator": self.generator,
+             "tokenizer": self.tokenizer,
+             "max_length": self.max_length,
+             "repetition_penalty": self.repetition_penalty,
+             "temperature": self.temperature,
+             "topk": self.topk,
+         }
+
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+
+ # Print tokens to stdout as they stream in.
+ llm = CTranslate2StreamLLM(
+     generator=generator,
+     tokenizer=tokenizer,
+     callbacks=[StreamingStdOutCallbackHandler()],
+ )
+
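+ # Quick smoke test (hypothetical prompt; the handler streams output to stdout):
+ #     llm("USER: Hello\nASSISTANT: ")
+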
+ # Launch the web UI
+ import os
+ import itertools
+ import gradio as gr
+
+ def make_prompt(message, chat_history, max_context_size: int = 10):
+     # Append the new user message with an empty assistant slot, then flatten
+     # the [user, assistant] pairs into a single list of utterances.
+     contexts = chat_history + [[message, ""]]
+     contexts = list(itertools.chain.from_iterable(contexts))
+     if max_context_size > 0:
+         context_size = max_context_size - 1
+     else:
+         context_size = 100000
+     contexts = contexts[-context_size:]
+     # Walk backwards so the final, empty slot is always labeled ASSISTANT.
+     prompt = []
+     for idx, context in enumerate(reversed(contexts)):
+         if idx % 2 == 0:
+             prompt = [f"ASSISTANT: {context}"] + prompt
+         else:
+             prompt = [f"USER: {context}"] + prompt
+     prompt = "\n".join(prompt)
+     return prompt
+
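+ # For illustration, one prior exchange yields a prompt that ends in an empty
+ # ASSISTANT slot for the model to complete:
+ #     make_prompt("How are you?", [["Hello", "Hi there!"]])
+ #     -> "USER: Hello\nASSISTANT: Hi there!\nUSER: How are you?\nASSISTANT: "
+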
+
+ def interact_func(message, chat_history, max_context_size):
+     prompt = make_prompt(message, chat_history, max_context_size)
+     print(f"prompt: {prompt}")
+     generated = llm(prompt)
+     # Strip any leaked "USER" turn the model may have started to generate.
+     generated = generated.replace("\nUSER", "")
+     print(f"generated: {generated}")
+     chat_history.append((message, generated))
+     # Clear the textbox and return the updated history.
+     yield "", chat_history
+
+
+ with gr.Blocks(theme="monochrome") as demo:
+     gr.Markdown(TITLE)
+     with gr.Accordion("Configs", open=False):
+         # max_context_size = the number of turns * 2
+         max_context_size = gr.Number(value=20, label="Conversation turns to remember", precision=0)
+         # Note: this control is shown but not yet wired to the model.
+         max_length = gr.Number(value=128, label="Maximum length", precision=0)
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox()
+     clear = gr.Button("Clear")
+     # Submit goes straight to interact_func, which expects the message, the
+     # history, and the context size, and yields the cleared textbox plus the
+     # updated history. As a generator it must run through the queue.
+     msg.submit(
+         interact_func,
+         [msg, chatbot, max_context_size],
+         [msg, chatbot],
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ if __name__ == "__main__":
+     # The queue is required for generator-based (streaming) event handlers.
+     demo.queue()
+     demo.launch(debug=True, share=True)
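+ # Sketch (assumption, not in the original app): the unused max_length control
+ # could be wired in by passing it as a fourth input to msg.submit and applying
+ # it before generating:
+ #
+ #     def interact_func(message, chat_history, max_context_size, max_length):
+ #         llm.max_length = int(max_length)
+ #         ...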