# ウェブUIの起動 import os import itertools import torch from transformers import AutoTokenizer import ctranslate2 import gradio as gr DESCRIPTION=""" ## 概要 - これは、とある研究発表のために作られたチャットルーム(スペース)です。アクセス過多の場合は、少し時間をおいてから再度アクセスしてください。 - 詳細設定にて、AIが生成する文章のテイストを調整することが出来ます。 - AIの名前は「ベータ」です。 ## たのむぞ - あまり個人情報を入力しないでください。 - 会話内容は収集しておりません。 """ device = "cuda" if torch.cuda.is_available() else "cpu" generator = ctranslate2.Generator("./FixedStar-BETA-7b-ct2", device=device) tokenizer = AutoTokenizer.from_pretrained( "./tokenizer", use_fast=True) def inference_func(prompt, max_length=64, sampling_temperature=0.7): tokens = tokenizer.convert_ids_to_tokens( tokenizer.encode(prompt, add_special_tokens=False) ) results = generator.generate_batch( [tokens], max_length=max_length, sampling_topk=20, sampling_temperature=sampling_temperature, repetition_penalty=1.1, end_token=[26168, 27, 208, 14719, 9078, 18482, 27, 208], include_prompt_in_result=False, ) output = tokenizer.decode(results[0].sequences_ids[0]) return output def make_prompt(message, chat_history, max_context_size: int = 10): contexts = chat_history + [[message, ""]] contexts = list(itertools.chain.from_iterable(contexts)) if max_context_size > 0: context_size = max_context_size - 1 else: context_size = 100000 contexts = contexts[-context_size:] prompt = [] for idx, context in enumerate(reversed(contexts)): if idx % 2 == 0: prompt = [f"ASSISTANT: {context}"] + prompt else: prompt = [f"USER: {context}"] + prompt prompt = "\n".join(prompt) return prompt def interact_func(message, chat_history, max_context_size, max_length, sampling_temperature): prompt = make_prompt(message, chat_history, max_context_size) print(f"prompt: {prompt}") generated = inference_func(prompt, max_length, sampling_temperature) print(f"generated: {generated}") chat_history.append((message, generated)) return "", chat_history with gr.Blocks(theme="monochrome") as demo: with gr.Accordion("Configs", open=False): # max_context_size = the number of turns * 2 max_context_size = gr.Number(value=20, label="記憶する会話ターン数", precision=0) max_length = gr.Number(value=64, label="最大文字数", precision=0) sampling_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="創造性") chatbot = gr.Chatbot() msg = gr.Textbox() clear = gr.Button("消す") msg.submit( interact_func, [msg, chatbot, max_context_size, max_length, sampling_temperature], [msg, chatbot], ) clear.click(lambda: None, None, chatbot, queue=False) gr.Markdown(DESCRIPTION) if __name__ == "__main__": demo.launch(debug=True, share=True)