ほしゆめ committed on
Commit
478dc1c
1 Parent(s): 61da86f

Upload app.py

Files changed (1)
app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
# Gradio chat demo backed by a CTranslate2-converted language model.
import itertools

import ctranslate2
import gradio as gr
from transformers import AutoTokenizer

# Load the converted model and the tokenizer of the original checkpoint.
generator = ctranslate2.Generator("FixedStar-DebugChat/models/ct2-model")
tokenizer = AutoTokenizer.from_pretrained("models", use_fast=False)

# Fixed few-shot preamble. The model expects the Japanese role markers
# ユーザー ("user") and システム ("system"); the two Japanese turns read
# "Keep it up." and "Understood."
static_prompt = """ユーザー: We will now start chatting. If spoken to in English, answer in English; if spoken to in Japanese, answer in Japanese. Please take a deep breath and calm down and have a conversation.
システム: I'll try to keep calm and have a conversation.
ユーザー: その調子で頑張ってください。
システム: 分かりました。
"""
# Tokenize the preamble once; generate_batch reuses it via static_prompt.
system_prompt_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(static_prompt, add_special_tokens=False))
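# For reference, a minimal sketch of how a CTranslate2 model directory like
# the one loaded above can be produced from a Hugging Face checkpoint. The
# source path and quantization mode are illustrative assumptions, not taken
# from this commit:
#
#   from ctranslate2.converters import TransformersConverter
#   TransformersConverter("models").convert("FixedStar-DebugChat/models/ct2-model", quantization="int8")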
def inference_func(prompt, max_length=128, sampling_topk=40, sampling_topp=0.75, sampling_temperature=0.7, repetition_penalty=1.4):
    # Tokenize the chat prompt and sample a completion on top of the
    # pre-tokenized static prompt.
    tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt, add_special_tokens=False))
    results = generator.generate_batch(
        [tokens],
        static_prompt=system_prompt_tokens,
        max_length=max_length,
        sampling_topk=sampling_topk,
        sampling_topp=sampling_topp,
        sampling_temperature=sampling_temperature,
        repetition_penalty=repetition_penalty,
        include_prompt_in_result=False,  # return only the newly generated tokens
    )
    output = tokenizer.decode(results[0].sequences_ids[0])
    return output
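# A quick sanity check of inference_func outside the UI; the prompt text is an
# illustrative assumption, and the trailing "システム: " cues the model to answer:
#
#   print(inference_func("ユーザー: 自己紹介してください。\nシステム: "))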
def make_prompt(message, chat_history, max_context_size: int = 10):
    # Flatten the [user, assistant] pairs into one alternating list, ending
    # with an empty slot for the reply to be generated.
    contexts = chat_history + [[message, ""]]
    contexts = list(itertools.chain.from_iterable(contexts))
    if max_context_size > 0:
        context_size = max_context_size - 1
    else:
        context_size = 100000  # effectively unlimited context
    contexts = contexts[-context_size:]
    # Walk backwards so that roles stay aligned even after truncation:
    # counting from the end, even offsets are システム turns and odd offsets
    # are ユーザー turns.
    prompt = []
    for idx, context in enumerate(reversed(contexts)):
        if idx % 2 == 0:
            prompt = [f"システム: {context}"] + prompt
        else:
            prompt = [f"ユーザー: {context}"] + prompt
    prompt = "\n".join(prompt)
    return prompt

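# For illustration (the example turns are assumptions): with
# chat_history = [["こんにちは", "こんにちは!"]] and message = "元気ですか?",
# make_prompt returns
#
#   ユーザー: こんにちは
#   システム: こんにちは!
#   ユーザー: 元気ですか?
#   システム:
#
# where the open システム line is left for the model to fill in.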
def interact_func(message, chat_history, max_context_size, max_length, sampling_topk, sampling_topp, sampling_temperature, repetition_penalty):
    # Build the prompt from the visible chat history, generate a reply, and
    # append the new turn to the Chatbot state.
    prompt = make_prompt(message, chat_history, max_context_size)
    print(f"prompt: {prompt}")
    generated = inference_func(prompt, max_length, sampling_topk, sampling_topp, sampling_temperature, repetition_penalty)
    print(f"generated: {generated}")
    chat_history.append((message, generated))
    return "", chat_history  # clear the textbox and refresh the chat window

with gr.Blocks(theme="monochrome") as demo:
    with gr.Accordion("Parameters", open=False):
        # max_context_size = the number of turns * 2
        max_context_size = gr.Number(value=10, label="max_context_size", precision=0)
        max_length = gr.Number(value=128, label="max_length", precision=0)
        sampling_topk = gr.Slider(0, 1000, value=40, step=1, label="top_k")  # top-k is an integer count
        sampling_topp = gr.Slider(0.1, 1.0, value=0.75, step=0.1, label="top_p")
        sampling_temperature = gr.Slider(0.0, 10.0, value=0.7, step=0.1, label="temperature")
        repetition_penalty = gr.Slider(0.0, 10.0, value=1.4, step=0.1, label="repetition_penalty")
    # avatar_images order is (user, bot); show_share_button expects a bool.
    chatbot = gr.Chatbot(show_copy_button=True, show_share_button=True, avatar_images=["icon.png", "user.png"])
    msg = gr.Textbox()
    clear = gr.Button("RESET")
    msg.submit(
        interact_func,
        [msg, chatbot, max_context_size, max_length, sampling_topk, sampling_topp, sampling_temperature, repetition_penalty],
        [msg, chatbot],
    )
    clear.click(lambda: None, None, chatbot, queue=False)  # RESET empties the chat

if __name__ == "__main__":
    demo.launch()
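# When hosting somewhere other than Hugging Face Spaces, one might bind
# explicitly, e.g. demo.queue().launch(server_name="0.0.0.0", server_port=7860);
# these values are illustrative assumptions, not settings from this commit.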