"""Gradio chat demo backed by a CTranslate2-converted causal language model.

The model in ``./ct2-model`` is driven through ``ctranslate2.Generator`` and
tokenized with the Hugging Face tokenizer in ``./models``. Every request is
rendered as a ``ユーザー:`` (user) / ``システム:`` (system) transcript and is
preceded by a fixed priming dialogue (``static_prompt``) that is handed to
CTranslate2 through its ``static_prompt`` argument.
"""

import itertools

import ctranslate2
import gradio as gr
from transformers import AutoTokenizer

generator = ctranslate2.Generator("./ct2-model")
tokenizer = AutoTokenizer.from_pretrained("./models", use_fast=False)

# Fixed priming dialogue placed in front of every conversation. It uses the
# same "<tag>: <text>" one-line-per-utterance format that make_prompt emits.
static_prompt = """ユーザー: We will now start chatting. If spoken to in English, answer in English; if spoken to in Japanese, answer in Japanese. Please take a deep breath and calm down and have a conversation.
システム: I'll try to keep calm and have a conversation.
ユーザー: その調子で頑張ってください。
システム: 分かりました。
"""


def _to_tokens(text):
    """Convert *text* to the token strings CTranslate2 consumes (no special tokens)."""
    return tokenizer.convert_ids_to_tokens(
        tokenizer.encode(text, add_special_tokens=False)
    )


system_prompt_tokens = _to_tokens(static_prompt)


def inference_func(
    prompt,
    max_length=128,
    sampling_topk=40,
    sampling_topp=0.75,
    sampling_temperature=0.7,
    repetition_penalty=1.4,
):
    """Generate a continuation of *prompt* after the static priming dialogue.

    Args:
        prompt: Transcript text produced by ``make_prompt``.
        max_length: Maximum generation length, forwarded to CTranslate2.
        sampling_topk: Top-k sampling cutoff; coerced to ``int`` because the
            UI slider may deliver a float and CTranslate2 expects an integer.
        sampling_topp: Nucleus (top-p) sampling threshold.
        sampling_temperature: Sampling temperature.
        repetition_penalty: Repetition penalty forwarded to CTranslate2.

    Returns:
        The decoded text of the first generated hypothesis. The prompt is
        excluded from the result (``include_prompt_in_result=False``).
    """
    results = generator.generate_batch(
        [_to_tokens(prompt)],
        static_prompt=system_prompt_tokens,
        max_length=int(max_length),
        sampling_topk=int(sampling_topk),
        sampling_topp=sampling_topp,
        sampling_temperature=sampling_temperature,
        repetition_penalty=repetition_penalty,
        include_prompt_in_result=False,
    )
    return tokenizer.decode(results[0].sequences_ids[0])


def make_prompt(message, chat_history, max_context_size: int = 10):
    """Render recent history plus *message* as a tagged transcript.

    Args:
        message: The new user utterance.
        chat_history: Iterable of ``(user, system)`` pairs, oldest first.
        max_context_size: Maximum number of utterances to include, counting
            the new message and its empty reply slot (i.e. turns * 2). Odd
            values are rounded down so the window never splits a turn; the
            current turn is always kept. ``None`` or values <= 0 disable
            truncation.

    Returns:
        Newline-joined ``ユーザー: ...`` / ``システム: ...`` lines ending with
        an empty ``システム: `` line for the model to complete.
    """
    utterances = list(
        itertools.chain.from_iterable([*chat_history, (message, "")])
    )

    if max_context_size and max_context_size > 0:
        # Cut on a pair boundary: an odd window would leave an orphaned
        # system reply at the top. A plain `-0` slice would also keep
        # everything, so enforce a minimum of one full turn.
        window = max(2, int(max_context_size) // 2 * 2)
        utterances = utterances[-window:]

    # The list holds whole (user, system) pairs, so even positions are user.
    lines = [
        f"ユーザー: {text}" if position % 2 == 0 else f"システム: {text}"
        for position, text in enumerate(utterances)
    ]
    return "\n".join(lines)


def interact_func(
    message,
    chat_history,
    max_context_size,
    max_length,
    sampling_topk,
    sampling_topp,
    sampling_temperature,
    repetition_penalty,
):
    """Gradio submit handler: generate a reply and append it to the history.

    Returns:
        ``("", updated_history)``. The empty string clears the input textbox,
        and ``updated_history`` is a new list with ``(message, reply)`` appended.
    """
    # After RESET the chatbot value is None; treat that as an empty history.
    history = list(chat_history or [])
    prompt = make_prompt(message, history, max_context_size)
    print(f"prompt: {prompt}")
    generated = inference_func(
        prompt,
        max_length,
        sampling_topk,
        sampling_topp,
        sampling_temperature,
        repetition_penalty,
    )
    print(f"generated: {generated}")
    history.append((message, generated))
    return "", history


with gr.Blocks(theme="monochrome") as demo:
    with gr.Accordion("Parameters", open=False):
        # max_context_size = the number of turns * 2
        max_context_size = gr.Number(value=10, label="max_context_size", precision=0)
        max_length = gr.Number(value=128, label="max_length", precision=0)
        # top_k counts candidate tokens, so it moves in whole-number steps.
        sampling_topk = gr.Slider(0, 1000, value=40, step=1, label="top_k")
        sampling_topp = gr.Slider(0.1, 1.0, value=0.75, step=0.1, label="top_p")
        sampling_temperature = gr.Slider(0.0, 10.0, value=0.7, step=0.1, label="temperature")
        repetition_penalty = gr.Slider(0.0, 10.0, value=1.4, step=0.1, label="repetition_penalty")

    # show_share_button takes a boolean; the former "RETRY" string was merely truthy.
    chatbot = gr.Chatbot(show_copy_button=True, show_share_button=True)
    msg = gr.Textbox()
    clear = gr.Button("RESET")

    msg.submit(
        interact_func,
        [
            msg,
            chatbot,
            max_context_size,
            max_length,
            sampling_topk,
            sampling_topp,
            sampling_temperature,
            repetition_penalty,
        ],
        [msg, chatbot],
    )
    clear.click(lambda: None, None, chatbot, queue=False)


if __name__ == "__main__":
    demo.launch(debug=True)