Yumenohoshi committed on
Commit
b183ad6
1 Parent(s): bfcc56b

Upload app.py

Files changed (1)
  1. app.py +73 -0
app.py ADDED
@@ -0,0 +1,73 @@
+ # Launch the web UI
+ import os
+ import itertools
+
+ import torch
+ from transformers import AutoTokenizer
+ import ctranslate2
+ import gradio as gr
+
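+ # NOTE (assumption): "./FixedStar-BETA-7b-ct2" is expected to hold a
+ # CTranslate2 conversion of the checkpoint, e.g. prepared beforehand with
+ # something like:
+ #   ct2-transformers-converter --model Yumenohoshi/Fixedstar-BETA \
+ #       --output_dir FixedStar-BETA-7b-ct2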
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ generator = ctranslate2.Generator("./FixedStar-BETA-7b-ct2", device=device)
+ tokenizer = AutoTokenizer.from_pretrained(
+     "Yumenohoshi/Fixedstar-BETA", use_fast=True)
+
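+ # generator.generate_batch consumes lists of token *strings* (not ids), hence
+ # the encode + convert_ids_to_tokens round-trip below;
+ # include_prompt_in_result=False makes it return only newly generated tokens.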
+ def inference_func(prompt, max_length=128, sampling_temperature=0.7):
+     tokens = tokenizer.convert_ids_to_tokens(
+         tokenizer.encode(prompt, add_special_tokens=False)
+     )
+     results = generator.generate_batch(
+         [tokens],
+         max_length=max_length,
+         sampling_topk=20,
+         sampling_temperature=sampling_temperature,
+         include_prompt_in_result=False,
+     )
+     output = tokenizer.decode(results[0].sequences_ids[0])
+     return output
+
+ def make_prompt(message, chat_history, max_context_size: int = 10):
+     contexts = chat_history + [[message, ""]]
+     contexts = list(itertools.chain.from_iterable(contexts))
+     if max_context_size > 0:
+         context_size = max_context_size - 1
+     else:
+         context_size = 100000
+     contexts = contexts[-context_size:]
+     prompt = []
+     for idx, context in enumerate(reversed(contexts)):
+         if idx % 2 == 0:
+             prompt = [f"ASSISTANT: {context}"] + prompt
+         else:
+             prompt = [f"USER: {context}"] + prompt
+     prompt = "\n".join(prompt)
+     return prompt
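+ # Example (illustrative): with chat_history = [["Hello", "Hi there!"]] and
+ # message = "How are you?", make_prompt returns:
+ #   USER: Hello
+ #   ASSISTANT: Hi there!
+ #   USER: How are you?
+ #   ASSISTANT:
+ # leaving the final assistant turn empty for the model to complete.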
+
+
+ def interact_func(message, chat_history, max_context_size, max_length, sampling_temperature):
+     prompt = make_prompt(message, chat_history, max_context_size)
+     print(f"prompt: {prompt}")
+     generated = inference_func(prompt, max_length, sampling_temperature)
+     print(f"generated: {generated}")
+     chat_history.append((message, generated))
+     return "", chat_history
+
+
+
+ with gr.Blocks(theme="monochrome") as demo:
+     with gr.Accordion("Configs", open=False):
+         # max_context_size = the number of turns * 2
+         max_context_size = gr.Number(value=20, label="Conversation turns to remember", precision=0)
+         max_length = gr.Number(value=128, label="Max output length", precision=0)
+         sampling_temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.1, label="Creativity (temperature)")
+     chatbot = gr.Chatbot()
+     msg = gr.Textbox()
+     clear = gr.Button("Clear")
+     msg.submit(
+         interact_func,
+         [msg, chatbot, max_context_size, max_length, sampling_temperature],
+         [msg, chatbot],
+     )
+     clear.click(lambda: None, None, chatbot, queue=False)
+
+ if __name__ == "__main__":
+     demo.launch(debug=True, share=True)
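+ # debug=True keeps launch() blocking and prints errors to the console;
+ # share=True additionally requests a temporary public Gradio link.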