BetaAI_Chat / app.py
Yumenohoshi's picture
Update app.py
66295fd verified
raw
history blame contribute delete
No virus
6.28 kB
from typing import Any, List, Union, Mapping, Optional, Iterable
import ctranslate2
from ctranslate2 import GenerationStepResult, Generator
import transformers
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast, AutoTokenizer
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
# Model and tokenizer are loaded from local directories at import time and
# shared by everything below.
generator = ctranslate2.Generator("./FixedStar-BETA-7b-ct2")
tokenizer = AutoTokenizer.from_pretrained("./tokenizer", use_fast=True)
# Text rendered via gr.Markdown at the top of the web UI.
TITLE = "Chat room!"
class CTranslate2StreamLLM(LLM):
    """LangChain LLM that generates text token by token with CTranslate2.

    Tokens are produced by ``generator.generate_tokens`` and, when a run
    manager is supplied, forwarded to it as soon as they decode to readable
    text. The complete decoded output is returned at the end of the call.
    """

    # CTranslate2 generator used for inference.
    generator: Generator
    # Tokenizer used to encode the prompt and decode generated ids.
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
    # Sampling settings passed straight to ``generate_tokens``.
    max_length: int = 128
    repetition_penalty: float = 1.1
    temperature: float = 0.7
    topk: int = 10

    @property
    def _llm_type(self) -> str:
        """Return the identifier LangChain uses for this LLM type."""
        return "CTranslate2"

    def _generate_tokens(
        self,
        prompt: str,
    ) -> Iterable[GenerationStepResult]:
        """Tokenize *prompt* and return the generator's step-result iterable."""
        # Run inference: generate_tokens is fed token strings, so the prompt
        # is encoded to ids and mapped back to tokens first.
        prompt_tokens = self.tokenizer.convert_ids_to_tokens(
            self.tokenizer.encode(prompt, add_special_tokens=False)
        )
        return self.generator.generate_tokens(
            prompt_tokens,
            max_length=self.max_length,
            sampling_topk=self.topk,
            sampling_temperature=self.temperature,
            repetition_penalty=self.repetition_penalty,
            # NOTE(review): hard-coded stop-token ids; presumably they are
            # tied to this model's vocabulary -- confirm before reusing the
            # class with another tokenizer.
            end_token=[26168, 27, 208, 14719, 9078, 18482, 27, 208],
        )

    def _decode_with_buffer(
        self, step_result: GenerationStepResult, token_buffer: list
    ) -> Union[str, None]:
        """Decode the buffered ids plus the new token into displayable text.

        Args:
            step_result: The generation step whose token is being added.
            token_buffer: Ids not yet emitted; mutated in place.

        Returns:
            The decoded text (prefixed with a space when the new token starts
            with "▁"), or ``None`` when the buffer does not decode to any
            regular character yet. In that case the ids stay buffered.
        """
        token_buffer.append(step_result.token_id)
        word = self.tokenizer.decode(token_buffer)
        # Nothing but replacement characters: give up for now and keep the
        # ids buffered. (An empty decode also lands here, because all() over
        # an empty string is True.)
        if all(c == "�" for c in word):
            return None
        # A token starting with "▁" marks a word boundary, so restore the
        # leading space.
        if step_result.token.startswith("▁"):
            word = " " + word
        # Readable text was produced, so the buffer can be cleared.
        token_buffer.clear()
        return word

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Generate a completion for *prompt*, streaming tokens to *run_manager*.

        Args:
            prompt: Text to complete.
            stop: Must be ``None``; stop sequences are not supported.
            run_manager: Optional callback manager notified of each new token.

        Returns:
            The decoded text of all generated ids, or "" if none were produced.

        Raises:
            ValueError: If *stop* is given.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        output_ids = []
        token_buffer = []
        for step_result in self._generate_tokens(prompt):
            output_ids.append(step_result.token_id)
            if not run_manager:
                continue
            word = self._decode_with_buffer(step_result, token_buffer)
            if word:
                run_manager.on_llm_new_token(
                    word,
                    verbose=self.verbose,
                    # Passed through unchanged: a log-probability of 0.0 is a
                    # valid value and must not be collapsed to None.
                    logprobs=step_result.log_prob,
                )

        return self.tokenizer.decode(output_ids) if output_ids else ""

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Async counterpart of ``_call`` that awaits the token callback.

        The generation loop itself iterates synchronously; only the
        ``on_llm_new_token`` callback is awaited.

        Args:
            prompt: Text to complete.
            stop: Must be ``None``; stop sequences are not supported.
            run_manager: Optional callback manager notified of each new token.

        Returns:
            The decoded text of all generated ids, or "" if none were produced.

        Raises:
            ValueError: If *stop* is given.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        output_ids = []
        token_buffer = []
        for step_result in self._generate_tokens(prompt):
            output_ids.append(step_result.token_id)
            if not run_manager:
                continue
            word = self._decode_with_buffer(step_result, token_buffer)
            if word:
                await run_manager.on_llm_new_token(
                    word,
                    verbose=self.verbose,
                    # Passed through unchanged: a log-probability of 0.0 is a
                    # valid value and must not be collapsed to None.
                    logprobs=step_result.log_prob,
                )

        return self.tokenizer.decode(output_ids) if output_ids else ""

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "generator": self.generator,
            "tokenizer": self.tokenizer,
            "max_length": self.max_length,
            "repetition_penalty": self.repetition_penalty,
            "temperature": self.temperature,
            "topk": self.topk,
        }
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# Module-level LLM instance called by the chat handler below. The attached
# callback handler presumably echoes streamed tokens to stdout.
llm = CTranslate2StreamLLM(
    generator=generator,
    tokenizer=tokenizer,
    callbacks=[StreamingStdOutCallbackHandler()])
# Launch the web UI
import os
import itertools
import gradio as gr
def make_prompt(message, chat_history, max_context_size: int = 10):
    """Build the USER/ASSISTANT transcript used as the model prompt.

    The history pairs and the new message are flattened into one list of
    utterances, with an empty final entry so the prompt ends in an open
    "ASSISTANT: " cue. Roles are assigned by counting back from that cue.

    Args:
        message: The new user message.
        chat_history: Sequence of (user, assistant) pairs; ``None`` is
            treated as an empty history.
        max_context_size: A positive value N keeps the last N - 1 utterances,
            but never fewer than two (the current message and the assistant
            cue). Zero, a negative value or ``None`` keeps everything.

    Returns:
        The prompt, one "USER: ..." or "ASSISTANT: ..." line per utterance.
    """
    turns = list(chat_history or []) + [[message, ""]]
    utterances = list(itertools.chain.from_iterable(turns))

    if max_context_size is not None and max_context_size > 0:
        # Clamp to two entries: a window of 0 would slice as [-0:] and return
        # the whole history, and a window of 1 would drop the user's message.
        keep = max(int(max_context_size) - 1, 2)
        utterances = utterances[-keep:]

    # The final entry is always the assistant cue, so entries at an even
    # distance from the end belong to the assistant.
    last = len(utterances) - 1
    lines = [
        f"ASSISTANT: {text}" if (last - pos) % 2 == 0 else f"USER: {text}"
        for pos, text in enumerate(utterances)
    ]
    return "\n".join(lines)
def interact_func(message, chat_history, max_context_size):
    """Answer *message* and yield the updated chat state for the UI.

    Builds the prompt from the history, runs the module-level ``llm``,
    removes every "\\nUSER" occurrence from the reply, appends the new
    (message, reply) pair to *chat_history* in place, and yields once.

    Yields:
        A tuple of "" and the updated *chat_history*.
    """
    transcript = make_prompt(message, chat_history, max_context_size)
    print(f"prompt: {transcript}")

    reply = llm(transcript).replace("\nUSER", "")
    print(f"generated: {reply}")

    chat_history.append((message, reply))
    yield "", chat_history
# Chat UI: a title, a collapsed settings panel, the conversation view, an
# input box and a clear button.
with gr.Blocks(theme="monochrome") as demo:
    gr.Markdown(TITLE)

    with gr.Accordion("Configs", open=False):
        # max_context_size = the number of turns * 2
        max_context_size = gr.Number(value=20, label="記憶する会話ターン数", precision=0)
        # NOTE(review): max_length is not listed as an input of any event
        # handler in this section, so changing it has no visible effect here.
        max_length = gr.Number(value=128, label="最大文字数", precision=0)

    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("消す")

    # Submitting the textbox runs the chat handler; its two yielded values
    # are written back to the textbox and the chatbot.
    msg.submit(
        fn=interact_func,
        inputs=[msg, chatbot, max_context_size],
        outputs=[msg, chatbot],
    )
    # The clear button sets the chatbot value to None.
    clear.click(fn=lambda: None, inputs=None, outputs=chatbot, queue=False)
if __name__ == "__main__":
    # Enable the request queue, then start the server with debug output and
    # share=True.
    demo.queue().launch(debug=True, share=True)