Spaces:
Runtime error
Runtime error
File size: 6,281 Bytes
b9d2c44 b4bcaec b9d2c44 66295fd b9d2c44 8f6c2ca b9d2c44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union

import ctranslate2
import transformers
from ctranslate2 import GenerationStepResult, Generator
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
# Load the converted CTranslate2 model and its tokenizer once, at import time,
# from the local directories shipped next to this script.
generator = ctranslate2.Generator("./FixedStar-BETA-7b-ct2")
tokenizer = AutoTokenizer.from_pretrained(
    "./tokenizer",
    use_fast=True,
)

# Heading rendered at the top of the Gradio page.
TITLE = "Chat room!"
class CTranslate2StreamLLM(LLM):
    """LangChain ``LLM`` backed by a CTranslate2 ``Generator`` with token streaming.

    Generation runs step by step through ``Generator.generate_tokens``. When a
    callback manager is supplied, each decodable piece of text is forwarded to
    ``on_llm_new_token`` as soon as it is available; the full completion is
    decoded once at the end and returned.
    """

    generator: Generator
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
    # Passed to generate_tokens(max_length=...).
    max_length: int = 128
    repetition_penalty: float = 1.1
    temperature: float = 0.7
    # Passed to generate_tokens(sampling_topk=...).
    topk: int = 10
    # Passed to generate_tokens(end_token=...). CTranslate2 stops on ANY one of
    # these ids, not on the list as a sequence.
    # NOTE(review): the repeated ids (27, 208) suggest a multi-token stop string
    # was intended — confirm against the tokenizer.
    end_token: List[int] = [26168, 27, 208, 14719, 9078, 18482, 27, 208]

    @property
    def _llm_type(self) -> str:
        return "CTranslate2"

    def _generate_tokens(
        self,
        prompt: str,
    ) -> Iterable[GenerationStepResult]:
        """Start step-by-step generation for *prompt* and return the step iterator."""
        # CTranslate2 consumes token strings; encode without tokenizer-added
        # special tokens so the prompt is passed through as written.
        tokens = self.tokenizer.convert_ids_to_tokens(
            self.tokenizer.encode(prompt, add_special_tokens=False)
        )
        return self.generator.generate_tokens(
            tokens,
            max_length=self.max_length,
            sampling_topk=self.topk,
            sampling_temperature=self.temperature,
            repetition_penalty=self.repetition_penalty,
            end_token=self.end_token,
        )

    def _decode_with_buffer(
        self, step_result: GenerationStepResult, token_buffer: list
    ) -> Optional[str]:
        """Decode buffered token ids once they form displayable text.

        Appends ``step_result.token_id`` to *token_buffer* and decodes the whole
        buffer. If the result is empty or made only of U+FFFD replacement
        characters (typically a partially generated multi-byte character),
        returns ``None`` and keeps the buffer so later tokens can complete it.
        Otherwise clears the buffer and returns the decoded text.
        """
        token_buffer.append(step_result.token_id)
        word = self.tokenizer.decode(token_buffer)
        # Not yet decodable: keep buffering (all() is also True for "").
        if all(c == "\ufffd" for c in word):
            return None
        # A leading "▁" marks a word boundary; restore it as a space.
        # NOTE(review): only the newest token is inspected, not the first
        # buffered one.
        if step_result.token.startswith("▁"):
            word = " " + word
        token_buffer.clear()
        return word

    def _iter_pieces(
        self, prompt: str, output_ids: List[int], stream: bool
    ) -> Iterable[Tuple[str, Optional[float]]]:
        """Run generation for *prompt*, appending every token id to *output_ids*.

        When *stream* is true, yields ``(text, log_prob)`` for each decodable
        piece; otherwise only collects ids (no per-step decoding) and yields
        nothing.
        """
        token_buffer: List[int] = []
        for step_result in self._generate_tokens(prompt):
            output_ids.append(step_result.token_id)
            if not stream:
                continue
            word = self._decode_with_buffer(step_result, token_buffer)
            if word:
                # Pass log_prob through unchanged: 0.0 is a valid value and
                # must not be collapsed to None.
                yield word, step_result.log_prob

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Generate a completion, streaming text pieces to *run_manager*.

        Raises:
            ValueError: if *stop* is given; custom stop sequences are unsupported.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        output_ids: List[int] = []
        for word, log_prob in self._iter_pieces(prompt, output_ids, bool(run_manager)):
            run_manager.on_llm_new_token(
                word,
                verbose=self.verbose,
                logprobs=log_prob,
            )
        return self.tokenizer.decode(output_ids) if output_ids else ""

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> str:
        """Async variant of ``_call``; awaits the async callback manager.

        NOTE: the CTranslate2 step iterator is consumed synchronously, so the
        event loop is blocked while each step is being generated.

        Raises:
            ValueError: if *stop* is given; custom stop sequences are unsupported.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        output_ids: List[int] = []
        for word, log_prob in self._iter_pieces(prompt, output_ids, bool(run_manager)):
            await run_manager.on_llm_new_token(
                word,
                verbose=self.verbose,
                logprobs=log_prob,
            )
        return self.tokenizer.decode(output_ids) if output_ids else ""

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "generator": self.generator,
            "tokenizer": self.tokenizer,
            "max_length": self.max_length,
            "repetition_penalty": self.repetition_penalty,
            "temperature": self.temperature,
            "topk": self.topk,
            "end_token": self.end_token,
        }
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# One shared LLM instance for the app; streamed pieces are handed to
# StreamingStdOutCallbackHandler via the callbacks list.
llm = CTranslate2StreamLLM(
    tokenizer=tokenizer,
    generator=generator,
    callbacks=[StreamingStdOutCallbackHandler()],
)
# ウェブUIの起動
import os
import itertools
import gradio as gr
def make_prompt(message, chat_history, max_context_size: int = 10):
    """Build a ``USER:`` / ``ASSISTANT:`` transcript ending in an open assistant turn.

    The history is flattened to ``[user, assistant, ..., message, ""]`` and only
    the trailing ``max_context_size - 1`` utterances are kept, but never fewer
    than the current message plus the empty assistant slot. Roles are assigned
    by parity counted from the end, so the last line is always ``"ASSISTANT: "``.

    Args:
        message: the new user utterance.
        chat_history: sequence of ``(user, assistant)`` pairs; ``None`` means empty.
        max_context_size: utterance budget as described above; ``<= 0`` keeps up
            to 100000 utterances.

    Returns:
        The prompt lines joined with ``"\\n"``.
    """
    history = list(chat_history or [])
    contexts = list(itertools.chain.from_iterable(history + [[message, ""]]))
    max_context_size = int(max_context_size)
    if max_context_size > 0:
        # Clamp to 2: a size of 0 would hit contexts[-0:] (the WHOLE list) and
        # a size of 1 would drop the user's current message.
        context_size = max(max_context_size - 1, 2)
    else:
        context_size = 100000
    contexts = contexts[-context_size:]

    last = len(contexts) - 1
    lines = []
    for idx, context in enumerate(contexts):
        # The final element is the empty assistant slot; roles alternate backwards.
        role = "ASSISTANT" if (last - idx) % 2 == 0 else "USER"
        lines.append(f"{role}: {context}")
    return "\n".join(lines)
def interact_func(message, chat_history, max_context_size, max_length=None):
    """Gradio handler: generate a reply to *message* and append it to the chat.

    Args:
        message: text submitted in the textbox.
        chat_history: current Chatbot value (``(user, bot)`` pairs); ``None`` means empty.
        max_context_size: forwarded to ``make_prompt``.
        max_length: optional value of the "最大文字数" control; when given and
            positive it is assigned to ``llm.max_length`` before generating.

    Yields:
        ``("", history)`` once: clears the textbox and shows the updated history.
    """
    chat_history = list(chat_history or [])
    if max_length is not None and int(max_length) > 0:
        # NOTE(review): mutates the shared module-level ``llm``; only safe while
        # requests run one at a time — confirm the queue's concurrency settings.
        llm.max_length = int(max_length)
    prompt = make_prompt(message, chat_history, max_context_size)
    print(f"prompt: {prompt}")
    generated = llm(prompt)
    # Drop any "\nUSER" marker the model emitted (presumably the start of an
    # invented next user turn).
    generated = generated.replace("\nUSER", "")
    print(f"generated: {generated}")
    chat_history.append((message, generated))
    yield "", chat_history


with gr.Blocks(theme="monochrome") as demo:
    gr.Markdown(TITLE)
    with gr.Accordion("Configs", open=False):
        # Counted in utterances (user + assistant), i.e. roughly turns * 2.
        max_context_size = gr.Number(value=20, label="記憶する会話ターン数", precision=0)
        max_length = gr.Number(value=128, label="最大文字数", precision=0)
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("消す")
    # max_length is now wired in; previously the control had no effect.
    msg.submit(
        interact_func,
        [msg, chatbot, max_context_size, max_length],
        [msg, chatbot],
    )
    clear.click(lambda: None, None, chatbot, queue=False)
# Entry point: enable the request queue (needed for generator handlers) and
# serve the UI. The stray " |" scrape residue after launch() was a SyntaxError.
if __name__ == "__main__":
    demo.queue()
    demo.launch(debug=True, share=True)