Spaces:
Runtime error
Runtime error
ほしゆめ
committed on
Commit
•
0ff13f4
1
Parent(s):
b00a232
Upload 3 files
Browse files- ChatSystem.py +53 -0
- FixedStar-icon.png +0 -0
- user-icon.png +0 -0
ChatSystem.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import gradio as gr
|
3 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
4 |
+
|
5 |
+
class RinnaTalk():
    """Minimal chat wrapper around the Hoshiyume/FixedStar-BETA causal LM.

    Keeps the running conversation text in ``self.prompt`` and feeds the
    whole accumulated prompt back to the model on every turn.
    """

    def __init__(self, tokenizer=None, model=None):
        # Accumulated conversation prompt, grown turn by turn in chat().
        self.prompt = ''
        # Load the model and tokenizer up front; callers may inject their own
        # (the conditional skips the heavyweight from_pretrained download).
        self.tokenizer = AutoTokenizer.from_pretrained("Hoshiyume/FixedStar-BETA", use_fast=False) if tokenizer is None else tokenizer
        self.model = AutoModelForCausalLM.from_pretrained("Hoshiyume/FixedStar-BETA", torch_dtype=torch.float16) if model is None else model

    def chat(self, message: str, chat_history: list, max_token_length: int = 128, min_token_length: int = 10, temperature: float = 0.8,
             top_p: float = 0.75, top_k: int = 40, do_sample: bool = True, num_beams: int = 1):
        """Generate one reply to *message* and append the exchange to *chat_history*.

        Returns ``("", chat_history)`` so a Gradio ``Textbox.submit`` handler can
        clear the input box and refresh the Chatbot in a single call.

        Fixes vs. the original: ``top_p``/``top_k``/``do_sample``/``num_beam``
        were read as undefined globals (NameError on first message); they are
        now keyword parameters with the UI's default values, and generate()
        receives the correct ``num_beams`` keyword (``num_beam`` is not a valid
        transformers generation argument).
        """
        # The UI's ClearButton empties chat_history; mirror that by resetting
        # the accumulated prompt so a cleared chat starts fresh.
        if len(chat_history) == 0:
            self.prompt = ''

        self.prompt += f'ユーザー: {message}\nシステム: '
        token_ids = self.tokenizer.encode(self.prompt, add_special_tokens=False, return_tensors="pt")
        with torch.no_grad():
            output_ids = self.model.generate(
                token_ids,
                max_new_tokens=max_token_length,
                min_new_tokens=min_token_length,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
                temperature=temperature,
                num_beams=num_beams,
                pad_token_id=self.tokenizer.pad_token_id,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
        output = self.tokenizer.decode(output_ids.tolist()[0])
        # Take the text after the last <NL> separator and drop the trailing
        # end-of-sequence marker. The original rstrip('</s>') stripped any run
        # of '<', '/', 's', '>' characters, truncating replies ending in them;
        # removesuffix removes exactly the '</s>' token.
        latest_reply = output.split('<NL>')[-1].removesuffix('</s>')
        chat_history.append([message, latest_reply])
        self.prompt += f'{latest_reply}\n'

        return "", chat_history
|
39 |
+
# Single shared model/chat instance; every Gradio session appends to the
# same RinnaTalk.prompt state.
rinna = RinnaTalk()

with gr.Blocks() as demo:
    chatbot = gr.Chatbot( label="FixedStar-DebugChat", show_copy_button=True, show_share_button=True, avatar_images=["user-icon.png", "FixedStar-icon.png"] )
    max_token_length = gr.Slider( value=512, minimum=10, maximum=512, label='max_token_length' )
    min_token_length = gr.Slider( value=1, minimum=1, maximum=512, label='min_token_length' )
    top_p = gr.Slider( value=0.75, minimum=0, maximum=1, label='top_p' )
    top_k = gr.Slider( value=40, minimum=1, maximum=1000, label='top_k' )
    # Fix: the original passed ``value=`` twice (0.9 and 0.8) — a SyntaxError
    # ("keyword argument repeated") that crashed the app on import — and used
    # ``scale`` (a layout-ratio int) where the slider step size was intended.
    temperature = gr.Slider( value=0.8, minimum=0, maximum=1, step=0.01, label='temperature' )
    do_sample = gr.Checkbox( value=True, label='do_sample' )
    num_beam = gr.Slider( value=0, minimum=0, maximum=100, label='num_beam' )
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])
    # NOTE(review): top_p / top_k / do_sample / num_beam sliders are rendered
    # but not listed in the submit inputs, so changing them has no effect —
    # confirm whether they should be wired through to the handler.
    msg.submit(rinna.chat, [msg, chatbot, max_token_length, min_token_length, temperature], [msg, chatbot])

demo.launch()
|
FixedStar-icon.png
ADDED
user-icon.png
ADDED