ほしゆめ committed on
Commit 0ff13f4
1 Parent(s): b00a232

Upload 3 files

Files changed (3)
  1. ChatSystem.py +53 -0
  2. FixedStar-icon.png +0 -0
  3. user-icon.png +0 -0
ChatSystem.py ADDED
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM


class RinnaTalk:
    def __init__(self, tokenizer=None, model=None):
        self.prompt = ''
        # Load the tokenizer and model up front so they are reused for every turn
        self.tokenizer = AutoTokenizer.from_pretrained("Hoshiyume/FixedStar-BETA", use_fast=False) if tokenizer is None else tokenizer
        self.model = AutoModelForCausalLM.from_pretrained("Hoshiyume/FixedStar-BETA", torch_dtype=torch.float16) if model is None else model

    def chat(self, message: str, chat_history: list, max_token_length: int = 128, min_token_length: int = 10,
             top_p: float = 0.75, top_k: int = 40, temperature: float = 0.8, do_sample: bool = True, num_beams: int = 1):
        # When the chat history has been cleared, clear the accumulated prompt as well
        if len(chat_history) == 0:
            self.prompt = ''

        self.prompt += f'ユーザー: {message}\nシステム: '
        token_ids = self.tokenizer.encode(self.prompt, add_special_tokens=False, return_tensors="pt")
        with torch.no_grad():
            output_ids = self.model.generate(
                token_ids,
                # Gradio sliders deliver floats, so cast the integer-valued controls
                max_new_tokens=int(max_token_length),
                min_new_tokens=int(min_token_length),
                top_p=top_p,
                top_k=int(top_k),
                do_sample=do_sample,
                temperature=temperature,
                num_beams=int(num_beams),
                pad_token_id=self.tokenizer.pad_token_id,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
        output = self.tokenizer.decode(output_ids.tolist()[0])
        # Keep only the newest system reply and strip the end-of-sequence marker
        latest_reply = output.split('<NL>')[-1].replace('</s>', '')
        chat_history.append([message, latest_reply])
        self.prompt += f'{latest_reply}\n'

        return "", chat_history


rinna = RinnaTalk()

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="FixedStar-DebugChat", show_copy_button=True, show_share_button=True,
                         avatar_images=["user-icon.png", "FixedStar-icon.png"])
    max_token_length = gr.Slider(value=512, minimum=10, maximum=512, label='max_token_length')
    min_token_length = gr.Slider(value=1, minimum=1, maximum=512, label='min_token_length')
    top_p = gr.Slider(value=0.75, minimum=0, maximum=1, label='top_p')
    top_k = gr.Slider(value=40, minimum=1, maximum=1000, step=1, label='top_k')
    temperature = gr.Slider(value=0.8, minimum=0, maximum=1, step=0.01, label='temperature')
    do_sample = gr.Checkbox(value=True, label='do_sample')
    num_beams = gr.Slider(value=1, minimum=1, maximum=100, step=1, label='num_beams')
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])
    # Wire every generation control into chat() so the sliders actually take effect
    msg.submit(rinna.chat,
               [msg, chatbot, max_token_length, min_token_length, top_p, top_k, temperature, do_sample, num_beams],
               [msg, chatbot])

demo.launch()
FixedStar-icon.png ADDED
user-icon.png ADDED
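
For quick local testing, the snippet below is a minimal sketch of driving RinnaTalk.chat directly, without the Gradio UI. It assumes the RinnaTalk class from ChatSystem.py is available in the current session (running the file as-is also builds and launches the demo) and that the Hoshiyume/FixedStar-BETA weights can be downloaded; the keyword arguments mirror the controls wired up by msg.submit.

# Minimal usage sketch: exercise RinnaTalk.chat without launching the Gradio demo.
talk = RinnaTalk()   # loads the tokenizer and model once
history = []         # same list of [message, reply] pairs that gr.Chatbot passes in

# Each call appends [message, reply] to history and extends the running prompt internally.
_, history = talk.chat(
    "こんにちは", history,
    max_token_length=128, min_token_length=10,
    top_p=0.75, top_k=40, temperature=0.8, do_sample=True, num_beams=1,
)
print(history[-1][1])  # latest model reply

# Passing an emptied history resets the accumulated prompt, mirroring the ClearButton.
history = []
_, history = talk.chat("自己紹介してください", history)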