import streamlit as st
import numpy as np
import sentencepiece
from transformers import BertJapaneseTokenizer, BertModel
from sentence_transformers import SentenceTransformer
from sentence_transformers import models
import torch
import pandas as pd

# NOTE(review): BertJapaneseTokenizer for this model additionally needs a
# MeCab binding (e.g. fugashi + ipadic) to be installed — confirm the
# project's requirements.

st.title("質問箱")

# Constants
USER_NAME = "user"
ASSISTANT_NAME = "assistant"
MORIAGE_YAKU_NAME = "moriage_yaku"
MORIAGE_YAKU2_NAME = "moriage_yaku2"
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
QA_CSV_PATH = 'dict.csv'


@st.cache_resource
def _load_bert(model_name):
    """Load the BERT tokenizer and model once per server process.

    Streamlit re-executes the whole script on every interaction; without
    this cache the model weights would be reloaded from disk on each rerun.
    """
    tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)
    model.eval()
    return tokenizer, model


# Also exposed through session_state, as the original script did.
st.session_state.tokenizer, st.session_state.model = _load_bert(MODEL_NAME)


def sentence_to_vector(model, tokenizer, sentence):
    """Return a sentence embedding: the mean of BERT's last hidden state.

    Args:
        model: a ``transformers`` BertModel.
        tokenizer: the tokenizer matching ``model``.
        sentence: the text to embed.

    Returns:
        A 1-D tensor of length ``model.config.hidden_size``. The mean is
        taken over every token position, including [CLS] and [SEP].
    """
    # Truncate to the model's position-embedding limit; longer input would
    # otherwise make the forward pass fail.
    input_ids = tokenizer(
        sentence,
        truncation=True,
        max_length=model.config.max_position_embeddings,
        return_tensors="pt",
    )["input_ids"]
    with torch.no_grad():
        outputs = model(input_ids)
    # last_hidden_state is (batch=1, seq_len, hidden); average over seq_len.
    return outputs.last_hidden_state[0].mean(dim=0)


def calc_similarity(sentence1, sentence2):
    """Return the cosine similarity of two sentences' embeddings as a float."""
    model = st.session_state.model
    tokenizer = st.session_state.tokenizer
    vector1 = sentence_to_vector(model, tokenizer, sentence1)
    vector2 = sentence_to_vector(model, tokenizer, sentence2)
    return torch.nn.functional.cosine_similarity(vector1, vector2, dim=0).item()


@st.cache_resource(max_entries=4)
def _encode_questions(questions):
    """Embed every FAQ question, stacked into a (n_questions, hidden) tensor.

    The cache key is the tuple of question texts, so editing the CSV
    automatically produces fresh embeddings.
    """
    tokenizer, model = _load_bert(MODEL_NAME)
    return torch.stack(
        [sentence_to_vector(model, tokenizer, question) for question in questions]
    )


def _find_best_answer(user_text, qa_df):
    """Return ``(answer, score)`` for the question most similar to *user_text*.

    Returns ``(None, 0.0)`` when the table is empty or no question has a
    strictly positive cosine similarity. Ties go to the earliest row.
    """
    if qa_df.empty:
        return None, 0.0
    question_vectors = _encode_questions(tuple(str(q) for q in qa_df["question"]))
    # Encode the user's sentence once instead of once per question.
    user_vector = sentence_to_vector(
        st.session_state.model, st.session_state.tokenizer, user_text
    )
    scores = torch.nn.functional.cosine_similarity(
        user_vector.unsqueeze(0), question_vectors, dim=1
    )
    best_index = int(torch.argmax(scores))  # first maximum on ties
    best_score = float(scores[best_index])
    if best_score <= 0:
        return None, 0.0
    return qa_df["answer"].iloc[best_index], best_score


# Initialise the session's chat log (persists across reruns).
if "chat_log" not in st.session_state:
    st.session_state.chat_log = []

# Avatars per speaker.
# TODO: the moriage_yaku2 image avatar needs `from PIL import Image`:
# img_moriyage_yaku2 = np.array(Image.open("moriage_yaku2.jpeg"))
avator_img_dict = {
    MORIAGE_YAKU_NAME: "🎉",
    # MORIAGE_YAKU2_NAME: img_moriyage_yaku2,
}

user_msg = st.chat_input("質問、要望等あれば入力してください")

if user_msg:
    # Read the Q&A table on each message so edits to the CSV are picked up;
    # rows missing a question or an answer cannot be matched, so drop them.
    qa_df = pd.read_csv(QA_CSV_PATH, index_col=0).dropna(subset=["question", "answer"])
    best_answer, best_score = _find_best_answer(user_msg, qa_df)
    st.session_state.similar_word = best_answer if best_answer is not None else ""
    st.session_state.similar_value = best_score

    # Replay the earlier chat log.
    for chat in st.session_state.chat_log:
        avator = avator_img_dict.get(chat["name"], None)
        with st.chat_message(chat["name"], avatar=avator):
            st.write(chat["msg"])

    # Show the latest exchange.
    assistant_msg = "もう一度入力してください"
    moriage_yaku_msg = "アンコール!アンコール!"
    moriage_yaku2_msg = "そっれ、アンコール!アンコール!"
    # Fall back to the "please try again" message when nothing matched.
    assistant_reply = best_answer if best_answer is not None else assistant_msg

    with st.chat_message(USER_NAME):
        st.write(user_msg)
    with st.chat_message(ASSISTANT_NAME):
        st.write(assistant_reply)
    with st.chat_message(MORIAGE_YAKU_NAME, avatar=avator_img_dict[MORIAGE_YAKU_NAME]):
        st.write(moriage_yaku_msg)

    # Append what was actually shown, so the replayed history matches.
    st.session_state.chat_log.append({"name": USER_NAME, "msg": user_msg})
    st.session_state.chat_log.append({"name": ASSISTANT_NAME, "msg": assistant_reply})
    st.session_state.chat_log.append({"name": MORIAGE_YAKU_NAME, "msg": moriage_yaku_msg})
    # st.session_state.chat_log.append({"name": MORIAGE_YAKU2_NAME, "msg": moriage_yaku2_msg})