import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM


class RinnaTalk:
    def __init__(self, tokenizer=None, model=None):
        self.prompt = ''
        # Load the tokenizer and model up front unless they are injected
        self.tokenizer = AutoTokenizer.from_pretrained("models", use_fast=False) if tokenizer is None else tokenizer
        self.model = AutoModelForCausalLM.from_pretrained("models", torch_dtype=torch.float16) if model is None else model

    def chat(self, message: str, chat_history: list,
             max_token_length: int = 128, min_token_length: int = 10,
             temperature: float = 0.8, top_p: float = 0.75, top_k: int = 40,
             do_sample: bool = True, num_beams: int = 1):
        # Reset the prompt whenever the chat history has been cleared
        if len(chat_history) == 0:
            self.prompt = ''
        self.prompt += f'ユーザー: {message}\nシステム: '
        token_ids = self.tokenizer.encode(self.prompt, add_special_tokens=False, return_tensors="pt")
        with torch.no_grad():
            output_ids = self.model.generate(
                token_ids.to(self.model.device),
                max_new_tokens=max_token_length,
                min_new_tokens=min_token_length,
                top_p=top_p,
                top_k=top_k,
                do_sample=do_sample,
                temperature=temperature,
                num_beams=num_beams,
                pad_token_id=self.tokenizer.pad_token_id,
                bos_token_id=self.tokenizer.bos_token_id,
                eos_token_id=self.tokenizer.eos_token_id
            )
        output = self.tokenizer.decode(output_ids.tolist()[0])
        # Keep only the text after the last 'システム: ' marker (the newest reply)
        # and drop a trailing EOS token if the model emitted one
        latest_reply = output.split('システム: ')[-1].removesuffix(self.tokenizer.eos_token)
        chat_history.append([message, latest_reply])
        self.prompt += f'{latest_reply}\n'
        return "", chat_history


rinna = RinnaTalk()

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(
        label="FixedStar-DebugChat",
        show_copy_button=True,
        show_share_button=True,
        avatar_images=["user-icon.png", "FixedStar-icon.png"]
    )
    max_token_length = gr.Slider(value=512, minimum=10, maximum=512, label='max_token_length')
    min_token_length = gr.Slider(value=1, minimum=1, maximum=512, label='min_token_length')
    top_p = gr.Slider(value=0.75, minimum=0, maximum=1, label='top_p')
    top_k = gr.Slider(value=40, minimum=1, maximum=1000, label='top_k')
    temperature = gr.Slider(value=0.9, minimum=0, maximum=1, step=0.01, label='temperature')
    do_sample = gr.Checkbox(value=True, label='do_sample')
    num_beams = gr.Slider(value=1, minimum=1, maximum=100, step=1, label='num_beams')
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    # Pass every generation control into chat so the sliders actually take effect
    msg.submit(
        rinna.chat,
        [msg, chatbot, max_token_length, min_token_length, temperature, top_p, top_k, do_sample, num_beams],
        [msg, chatbot]
    )

demo.launch()