ほしゆめ commited on
Commit
5ee3947
1 Parent(s): f6d7a62

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import itertools
3
+ from transformers import AutoTokenizer
4
+ import ctranslate2
5
+
6
+ generator = ctranslate2.Generator("./ct2-model" )
7
+ tokenizer = AutoTokenizer.from_pretrained( "./models", use_fast=False)
8
+
9
+ static_prompt="""ユーザー: We will now start chatting. If spoken to in English, answer in English; if spoken to in Japanese, answer in Japanese. Please take a deep breath and calm down and have a conversation.
10
+ システム: I'll try to keep calm and have a conversation.
11
+ ユーザー: その調子で頑張ってください。
12
+ システム: 分かりました。
13
+ """
14
+ system_prompt_tokens=tokenizer.convert_ids_to_tokens(tokenizer.encode(static_prompt, add_special_tokens=False))
15
+
16
+ def inference_func(prompt, max_length=128, sampling_topk=40, sampling_topp=0.75, sampling_temperature=0.7, repetition_penalty=1.4):
17
+ tokens = tokenizer.convert_ids_to_tokens( tokenizer.encode(prompt, add_special_tokens=False))
18
+ results = generator.generate_batch(
19
+ [tokens],
20
+ static_prompt=system_prompt_tokens,
21
+ max_length=max_length,
22
+ sampling_topk=sampling_topk,
23
+ sampling_topp=sampling_topp,
24
+ sampling_temperature=sampling_temperature,
25
+ repetition_penalty=repetition_penalty,
26
+ include_prompt_in_result=False,
27
+ )
28
+ output = tokenizer.decode(results[0].sequences_ids[0])
29
+ return output
30
+
31
+ def make_prompt(message, chat_history, max_context_size: int = 10):
32
+ contexts = chat_history + [[message, ""]]
33
+ contexts = list(itertools.chain.from_iterable(contexts))
34
+ if max_context_size > 0:
35
+ context_size = max_context_size - 1
36
+ else:
37
+ context_size = 100000
38
+ contexts = contexts[-context_size:]
39
+ prompt = []
40
+ for idx, context in enumerate(reversed(contexts)):
41
+ if idx % 2 == 0:
42
+ prompt = [f"システム: {context}"] + prompt
43
+ else:
44
+ prompt = [f"ユーザー: {context}"] + prompt
45
+ prompt = "\n".join(prompt)
46
+ return prompt
47
+
48
+ def interact_func(message, chat_history, max_context_size, max_length, sampling_topk, sampling_topp, sampling_temperature, repetition_penalty ):
49
+ prompt = make_prompt(message, chat_history, max_context_size)
50
+ print(f"prompt: {prompt}")
51
+ generated = inference_func(prompt, max_length, sampling_topk, sampling_topp, sampling_temperature, repetition_penalty )
52
+ print(f"generated: {generated}")
53
+ chat_history.append((message, generated))
54
+ return "", chat_history
55
+
56
+ with gr.Blocks( theme="monochrome" ) as demo:
57
+ with gr.Accordion("Parameters", open=False):
58
+ # max_context_size = the number of turns * 2
59
+ max_context_size = gr.Number(value=10, label="max_context_size", precision=0)
60
+ max_length = gr.Number(value=128, label="max_length", precision=0)
61
+ sampling_topk = gr.Slider(0, 1000, value=40, step=0.1, label="top_k")
62
+ sampling_topp = gr.Slider(0.1, 1.0, value=0.75, step=0.1, label="top_p")
63
+ sampling_temperature = gr.Slider(0.0, 10.0, value=0.7, step=0.1, label="temperature")
64
+ repetition_penalty = gr.Slider(0.0, 10.0, value=1.4, step=0.1, label="repetition_penalty")
65
+ chatbot = gr.Chatbot( show_copy_button=True, show_share_button="RETRY" )
66
+ msg = gr.Textbox()
67
+ clear = gr.Button("RESET")
68
+ msg.submit(
69
+ interact_func,
70
+ [msg, chatbot, max_context_size, max_length, sampling_topk, sampling_topp, sampling_temperature, repetition_penalty],
71
+ [msg, chatbot],
72
+ )
73
+ clear.click(lambda: None, None, chatbot, queue=False)
74
+
75
+ if __name__ == "__main__":
76
+ demo.launch(debug=True)