# BetaAI_Chat / app.py
# ほしゆめ
# Upload app.py
# 5ee3947
# raw
# history blame
# No virus
# 3.55 kB
import gradio as gr
import itertools
from transformers import AutoTokenizer
import ctranslate2
# Load the CTranslate2 generator and its tokenizer once, at import time;
# every request below reuses these module-level objects.
generator = ctranslate2.Generator("./ct2-model")
tokenizer = AutoTokenizer.from_pretrained("./models", use_fast=False)

# Fixed conversation preamble placed before every prompt. It asks the model to
# reply in the language it is addressed in and to keep a calm tone.
static_prompt = """ユーザー: We will now start chatting. If spoken to in English, answer in English; if spoken to in Japanese, answer in Japanese. Please take a deep breath and calm down and have a conversation.
システム: I'll try to keep calm and have a conversation.
ユーザー: その調子で頑張ってください。
システム: 分かりました。
"""

# Tokenise the preamble a single time. inference_func hands these tokens to
# generate_batch as `static_prompt`.
_static_prompt_ids = tokenizer.encode(static_prompt, add_special_tokens=False)
system_prompt_tokens = tokenizer.convert_ids_to_tokens(_static_prompt_ids)
def inference_func(prompt, max_length=128, sampling_topk=40, sampling_topp=0.75, sampling_temperature=0.7, repetition_penalty=1.4):
    """Generate the model's continuation of *prompt* and return it as text.

    The prompt is tokenised without special tokens and generated after the
    module-level ``system_prompt_tokens`` preamble, which is passed to
    CTranslate2 as ``static_prompt``. ``include_prompt_in_result=False`` keeps
    the prompt tokens out of the returned sequence.

    Args:
        prompt: Conversation text, as produced by ``make_prompt``.
        max_length: Maximum generation length passed to CTranslate2.
            Coerced to ``int``.
        sampling_topk: Top-k sampling cutoff. Coerced to ``int``.
        sampling_topp: Nucleus (top-p) sampling threshold.
        sampling_temperature: Sampling temperature.
        repetition_penalty: Repetition penalty passed to CTranslate2.

    Returns:
        The decoded text of the first hypothesis of the single batch entry.
    """
    # UI widgets can deliver floats here (the top_k slider may step by 0.1),
    # but CTranslate2's Python binding only accepts integers for these two.
    max_length = int(max_length)
    sampling_topk = int(sampling_topk)

    token_ids = tokenizer.encode(prompt, add_special_tokens=False)
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    results = generator.generate_batch(
        [tokens],
        static_prompt=system_prompt_tokens,
        max_length=max_length,
        sampling_topk=sampling_topk,
        sampling_topp=sampling_topp,
        sampling_temperature=sampling_temperature,
        repetition_penalty=repetition_penalty,
        include_prompt_in_result=False,
    )
    # One prompt in the batch, so take the first result and its best hypothesis.
    return tokenizer.decode(results[0].sequences_ids[0])
def make_prompt(message, chat_history, max_context_size: int = 10):
    """Render the chat history plus the new message as a role-tagged prompt.

    The history pairs and the pending ``(message, "")`` pair are flattened into
    one alternating list: user, system, user, system, ..., message, "". The list
    is truncated to its most recent entries and rendered as one
    ``ユーザー: ...`` or ``システム: ...`` line per entry. The last line is always
    the empty ``システム: `` slot the model is expected to complete.

    Args:
        message: The user's new message.
        chat_history: Sequence of ``(user, system)`` pairs as held by the Gradio
            chatbot. ``None`` is treated as an empty history.
        max_context_size: Window size. The last ``max_context_size - 1``
            entries are kept, but never fewer than one, so the reply slot always
            survives. Values ``<= 0`` disable truncation (capped at 100000
            entries).

    Returns:
        The prompt lines joined with ``"\\n"``.
    """
    pairs = [*(chat_history or []), (message, "")]
    contexts = list(itertools.chain.from_iterable(pairs))

    if max_context_size > 0:
        # Clamp to >= 1. A budget of 1 would otherwise produce contexts[-0:],
        # which is the WHOLE history rather than the smallest window.
        context_size = max(max_context_size - 1, 1)
    else:
        context_size = 100000
    contexts = contexts[-context_size:]

    # Roles are assigned from the end. The final entry (the empty placeholder)
    # is a system turn, so entries an even distance from the end are system
    # turns and the rest are user turns.
    last = len(contexts) - 1
    lines = [
        f"システム: {context}" if (last - i) % 2 == 0 else f"ユーザー: {context}"
        for i, context in enumerate(contexts)
    ]
    return "\n".join(lines)
def interact_func(message, chat_history, max_context_size, max_length, sampling_topk, sampling_topp, sampling_temperature, repetition_penalty):
    """Handle one chat submission from the UI.

    Builds a prompt from *chat_history* and *message*, generates a reply, and
    appends the ``(message, reply)`` pair to *chat_history* in place. The prompt
    and reply are printed for debugging.

    Returns:
        A tuple of an empty string (clears the textbox) and the updated history.
    """
    prompt = make_prompt(message, chat_history, max_context_size)
    print(f"prompt: {prompt}")

    reply = inference_func(
        prompt,
        max_length=max_length,
        sampling_topk=sampling_topk,
        sampling_topp=sampling_topp,
        sampling_temperature=sampling_temperature,
        repetition_penalty=repetition_penalty,
    )
    print(f"generated: {reply}")

    chat_history.append((message, reply))
    cleared_textbox = ""
    return cleared_textbox, chat_history
with gr.Blocks(theme="monochrome") as demo:
    # Generation settings, collapsed by default.
    with gr.Accordion("Parameters", open=False):
        # max_context_size = the number of turns * 2; values <= 0 disable
        # history truncation in make_prompt.
        max_context_size = gr.Number(value=10, label="max_context_size", precision=0)
        max_length = gr.Number(value=128, label="max_length", precision=0)
        # top_k is a token count, so the slider steps in whole numbers. A
        # fractional step would send values like 40.3 to the generator.
        sampling_topk = gr.Slider(0, 1000, value=40, step=1, label="top_k")
        sampling_topp = gr.Slider(0.1, 1.0, value=0.75, step=0.1, label="top_p")
        sampling_temperature = gr.Slider(0.0, 10.0, value=0.7, step=0.1, label="temperature")
        repetition_penalty = gr.Slider(0.0, 10.0, value=1.4, step=0.1, label="repetition_penalty")
    # show_share_button takes a bool.
    chatbot = gr.Chatbot(show_copy_button=True, show_share_button=True)
    msg = gr.Textbox()
    clear = gr.Button("RESET")
    # Pressing Enter in the textbox runs interact_func. Its outputs clear the
    # textbox and replace the chatbot contents with the updated history.
    msg.submit(
        interact_func,
        [msg, chatbot, max_context_size, max_length, sampling_topk, sampling_topp, sampling_temperature, repetition_penalty],
        [msg, chatbot],
    )
    # RESET sets the chatbot value to None, clearing the conversation.
    clear.click(lambda: None, None, chatbot, queue=False)
def main() -> None:
    """Launch the Gradio demo with ``debug=True``."""
    demo.launch(debug=True)


if __name__ == "__main__":
    main()