"""Gradio chat-room demo backed by a CTranslate2 generator wrapped as a LangChain LLM.

The CTranslate2 model and the Hugging Face tokenizer are loaded from local
directories when this module is imported.
"""
import itertools
import os
from typing import Any, Iterable, List, Mapping, Optional, Union

import ctranslate2
import gradio as gr
import transformers
from ctranslate2 import GenerationStepResult, Generator
from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.llms.base import LLM
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast

generator = ctranslate2.Generator("./FixedStar-BETA-7b-ct2")
tokenizer = AutoTokenizer.from_pretrained("./tokenizer", use_fast=True)

TITLE = "Chat room!"


class CTranslate2StreamLLM(LLM):
    """LangChain LLM that runs a CTranslate2 ``Generator`` token by token.

    Each generated token is decoded incrementally and forwarded to the
    callback manager's ``on_llm_new_token``. The full decoded output is
    returned at the end.
    """

    generator: Generator
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
    max_length: int = 128
    repetition_penalty: float = 1.1
    temperature: float = 0.7
    topk: int = 10
    # Token ids passed to the generator as ``end_token``; generation stops when
    # one is produced. NOTE(review): presumably these encode the turn separator
    # (e.g. "\nUSER:") for ./tokenizer -- confirm.
    end_token: List[int] = [26168, 27, 208, 14719, 9078, 18482, 27, 208]

    @property
    def _llm_type(self) -> str:
        return "CTranslate2"

    def _generate_tokens(
        self,
        prompt: str,
    ) -> Iterable[GenerationStepResult]:
        """Run generation for ``prompt`` and return the per-step results.

        The prompt is encoded without special tokens and converted to token
        strings, the input form ``Generator.generate_tokens`` takes. Sampling
        settings come from the instance fields.
        """
        tokens = self.tokenizer.convert_ids_to_tokens(
            self.tokenizer.encode(prompt, add_special_tokens=False)
        )
        return self.generator.generate_tokens(
            tokens,
            max_length=self.max_length,
            sampling_topk=self.topk,
            sampling_temperature=self.temperature,
            repetition_penalty=self.repetition_penalty,
            end_token=self.end_token,
        )

    def _decode_with_buffer(
        self, step_result: GenerationStepResult, token_buffer: list
    ) -> Union[str, None]:
        """Append the step's token id to ``token_buffer`` and try to decode it.

        Tokens are accumulated while the buffer decodes only to U+FFFD.
        Presumably this happens for byte-fallback pieces of a multi-byte
        character that is not complete yet.

        Returns:
            ``None`` if every decoded character is U+FFFD (this includes an
            empty decode); the buffer is then kept for the next step.
            Otherwise the decoded text is returned and the buffer is cleared.
            A leading space is prepended when the newest token starts with
            the SentencePiece word marker "▁".
        """
        token_buffer.append(step_result.token_id)
        word = self.tokenizer.decode(token_buffer)

        # Nothing printable yet: keep buffering until the character completes.
        if all(c == "�" for c in word):
            return None

        # decode() does not reproduce the word-boundary space of a "▁"-prefixed
        # token when it is decoded on its own, so restore it here.
        if step_result.token.startswith("▁"):
            word = " " + word

        # Printable text was produced; start a fresh buffer.
        token_buffer.clear()
        return word

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Generate a completion for ``prompt`` and stream words to ``run_manager``.

        Raises:
            ValueError: if ``stop`` is given; stopping is controlled solely by
                ``end_token``.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        output_ids = []
        token_buffer = []
        for step_result in self._generate_tokens(prompt):
            output_ids.append(step_result.token_id)
            if run_manager:
                if word := self._decode_with_buffer(step_result, token_buffer):
                    run_manager.on_llm_new_token(
                        word,
                        verbose=self.verbose,
                        # log_prob is None unless requested; 0.0 is a valid value.
                        logprobs=step_result.log_prob,
                    )

        return self.tokenizer.decode(output_ids) if output_ids else ""

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> str:
        """Async counterpart of ``_call``; the callbacks are awaited.

        NOTE: generation itself still runs synchronously and blocks the event
        loop while each step is produced.

        Raises:
            ValueError: if ``stop`` is given.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        output_ids = []
        token_buffer = []
        for step_result in self._generate_tokens(prompt):
            output_ids.append(step_result.token_id)
            if run_manager:
                if word := self._decode_with_buffer(step_result, token_buffer):
                    # The async manager's on_llm_new_token is a coroutine; the
                    # sync manager's is not, which is why the type matters here.
                    await run_manager.on_llm_new_token(
                        word,
                        verbose=self.verbose,
                        logprobs=step_result.log_prob,
                    )

        return self.tokenizer.decode(output_ids) if output_ids else ""

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "generator": self.generator,
            "tokenizer": self.tokenizer,
            "max_length": self.max_length,
            "repetition_penalty": self.repetition_penalty,
            "temperature": self.temperature,
            "topk": self.topk,
            "end_token": self.end_token,
        }


llm = CTranslate2StreamLLM(
    generator=generator,
    tokenizer=tokenizer,
    callbacks=[StreamingStdOutCallbackHandler()],
)


# Launch the web UI.
def make_prompt(message, chat_history, max_context_size: int = 10):
    """Build a "USER: ..." / "ASSISTANT: ..." prompt ending in an open assistant turn.

    ``chat_history`` is a sequence of (user, assistant) pairs. It is flattened
    into alternating utterances, and ``message`` is appended together with an
    empty assistant reply. When ``max_context_size`` is positive, only the last
    ``max_context_size - 1`` utterances are kept. That window is never smaller
    than 2, so the new message and its reply slot always survive. Otherwise up
    to 100000 utterances are kept. Roles are assigned counting back from the
    end, so the last line is always "ASSISTANT: ".
    """
    pairs = list(chat_history or []) + [[message, ""]]
    utterances = list(itertools.chain.from_iterable(pairs))

    if max_context_size is not None and max_context_size > 0:
        # A window of 1 used to produce ``[-0:]`` (the whole history), and a
        # window of 2 dropped the user's message; clamp to keep both.
        context_size = max(int(max_context_size) - 1, 2)
    else:
        context_size = 100000
    utterances = utterances[-context_size:]

    last = len(utterances) - 1
    lines = [
        f"{'ASSISTANT' if (last - i) % 2 == 0 else 'USER'}: {text}"
        for i, text in enumerate(utterances)
    ]
    return "\n".join(lines)


def interact_func(message, chat_history, max_context_size, max_length=None):
    """Gradio submit handler: generate a reply and append the turn to the history.

    ``max_length`` is optional for backward compatibility. A positive value
    that differs from the shared ``llm`` creates a per-request LLM. That LLM
    shares the already-loaded generator and tokenizer, and uses the requested
    generation length.

    Yields:
        ``("", chat_history)``: clears the textbox and shows the updated chat.
    """
    if chat_history is None:
        chat_history = []
    prompt = make_prompt(message, chat_history, max_context_size)
    print(f"prompt: {prompt}")

    model = llm
    if max_length is not None and max_length > 0 and int(max_length) != llm.max_length:
        model = CTranslate2StreamLLM(
            generator=llm.generator,
            tokenizer=llm.tokenizer,
            max_length=int(max_length),
            repetition_penalty=llm.repetition_penalty,
            temperature=llm.temperature,
            topk=llm.topk,
            end_token=llm.end_token,
            callbacks=[StreamingStdOutCallbackHandler()],
        )

    generated = model(prompt)
    # Strip a leaked start of the next user turn from the reply.
    generated = generated.replace("\nUSER", "")
    print(f"generated: {generated}")
    chat_history.append((message, generated))
    yield "", chat_history


with gr.Blocks(theme="monochrome") as demo:
    gr.Markdown(TITLE)
    with gr.Accordion("Configs", open=False):
        # max_context_size = the number of turns * 2
        max_context_size = gr.Number(value=20, label="記憶する会話ターン数", precision=0)
        # Forwarded to interact_func as the generator's max_length.
        max_length = gr.Number(value=128, label="最大文字数", precision=0)
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("消す")

    msg.submit(
        interact_func,
        [msg, chatbot, max_context_size, max_length],
        [msg, chatbot],
    )
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.queue()
    # NOTE: share=True also publishes the demo through a public share link.
    demo.launch(debug=True, share=True)