# chatcot / app10.py
# pyamath — commit 38161ac: "Rename app.py to app10.py"
import streamlit as st
import numpy as np
import sentencepiece
from transformers import BertJapaneseTokenizer, BertModel
from sentence_transformers import SentenceTransformer
from sentence_transformers import models
import torch
import pandas as pd
# Page title ("Question box").
st.title("質問箱")

# Speaker names, used as st.chat_message names and as avatar lookup keys.
USER_NAME = "user"
ASSISTANT_NAME = "assistant"
MORIAGE_YAKU_NAME = "moriage_yaku"
MORIAGE_YAKU2_NAME = "moriage_yaku2"
# Pretrained Japanese BERT used to embed sentences.
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'


@st.cache_resource
def _load_bert(model_name):
    """Load the tokenizer and model for ``model_name`` once and reuse them.

    Streamlit re-executes this whole script on every interaction; caching the
    loaded objects as a shared resource avoids reloading the BERT weights on
    each rerun.

    Returns:
        (tokenizer, model) tuple.
    """
    tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
    model = BertModel.from_pretrained(model_name)
    return tokenizer, model


st.session_state.tokenizer, st.session_state.model = _load_bert(MODEL_NAME)
def sentence_to_vector(model, tokenizer, sentence):
    """Embed ``sentence`` as the mean of the model's last-layer token vectors.

    Args:
        model: BERT-style model; called as ``model(input_ids)`` and expected to
            return an object whose ``last_hidden_state`` has shape
            (1, seq_len, hidden_size).
        tokenizer: tokenizer where ``tokenizer(sentence, truncation=True)["input_ids"]``
            is a flat list of token ids.
        sentence: text to embed.

    Returns:
        1-D tensor of length hidden_size (mean over all tokens, including any
        special tokens the tokenizer added).
    """
    # Truncation keeps over-long inputs within the tokenizer's configured
    # maximum length instead of failing inside the model's forward pass.
    token_ids = tokenizer(sentence, truncation=True)["input_ids"]
    # Batch of one sentence: shape (1, seq_len).
    input_ids = torch.tensor(token_ids).reshape(1, -1)
    # Inference only: no autograd graph. Only the last layer is used, so the
    # intermediate hidden states are not requested or retained.
    with torch.no_grad():
        outputs = model(input_ids)
    last_hidden_state = outputs.last_hidden_state[0]  # (seq_len, hidden_size)
    return last_hidden_state.mean(dim=0)
def calc_similarity(sentence1, sentence2):
    """Return the cosine similarity between two sentences' BERT embeddings.

    Embeds both sentences with ``sentence_to_vector`` using the tokenizer and
    model stored in ``st.session_state``. The two vectors and the score are
    also stored in ``st.session_state`` (``sentence_vector1``,
    ``sentence_vector2``, ``score``).

    Returns:
        float: cosine similarity in [-1, 1]; higher means more similar.
    """
    st.session_state.sentence_vector1 = sentence_to_vector(st.session_state.model, st.session_state.tokenizer, sentence1)
    st.session_state.sentence_vector2 = sentence_to_vector(st.session_state.model, st.session_state.tokenizer, sentence2)
    st.session_state.score = torch.nn.functional.cosine_similarity(st.session_state.sentence_vector1, st.session_state.sentence_vector2, dim=0).detach().numpy().copy()
    # Callers compare this value numerically, so it must be returned
    # (without a return statement they would receive None).
    return float(st.session_state.score)
# Initialise the per-session chat history on the first run of the script.
if "chat_log" not in st.session_state:
    st.session_state.chat_log = []

# Avatar per speaker name; speakers not listed use Streamlit's default avatar.
# The MORIAGE_YAKU2 avatar is disabled; re-enabling it also needs `from PIL import Image`.
# img_moriyage_yaku2 = np.array(Image.open("moriage_yaku2.jpeg"))
avator_img_dict = {
    MORIAGE_YAKU_NAME: "🎉",
    # MORIAGE_YAKU2_NAME: img_moriyage_yaku2,
}

user_msg = st.chat_input("質問、要望等あれば入力してください")
if user_msg:
    st.session_state.sentence1 = user_msg
    st.session_state.similar_value = 0
    st.session_state.similar_word = ""
    # Q&A dictionary; must provide "question" and "answer" columns.
    st.session_state.df = pd.read_csv('dict.csv', index_col = 0)

    # Embed the user's message once, then score it against every stored
    # question (all rows, whatever their count or index labels) and keep the
    # answer of the best match.
    user_vector = sentence_to_vector(
        st.session_state.model, st.session_state.tokenizer, user_msg
    )
    for question, answer in zip(
        st.session_state.df["question"], st.session_state.df["answer"]
    ):
        question_vector = sentence_to_vector(
            st.session_state.model, st.session_state.tokenizer, question
        )
        value = torch.nn.functional.cosine_similarity(
            user_vector, question_vector, dim=0
        ).item()
        if value > st.session_state.similar_value:
            st.session_state.similar_value = value
            st.session_state.similar_word = answer

    # Replay earlier turns.
    for chat in st.session_state.chat_log:
        avator = avator_img_dict.get(chat["name"], None)
        with st.chat_message(chat["name"], avatar=avator):
            st.write(chat["msg"])

    # Show the latest turn.
    assistant_msg = "もう一度入力してください"
    moriage_yaku_msg = "アンコール!アンコール!"
    moriage_yaku2_msg = "そっれ、アンコール!アンコール!"
    # Best-matching answer; fall back to the "please enter again" message
    # when no stored question scored above 0.
    answer_msg = st.session_state.similar_word or assistant_msg
    with st.chat_message(USER_NAME):
        st.write(user_msg)
    with st.chat_message(ASSISTANT_NAME):
        st.write(answer_msg)
    with st.chat_message(MORIAGE_YAKU_NAME, avatar=avator_img_dict[MORIAGE_YAKU_NAME]):
        st.write(moriage_yaku_msg)

    # Record what each speaker actually said so the replay above is faithful.
    st.session_state.chat_log.append({"name": USER_NAME, "msg": user_msg})
    st.session_state.chat_log.append({"name": ASSISTANT_NAME, "msg": answer_msg})
    st.session_state.chat_log.append({"name": MORIAGE_YAKU_NAME, "msg": moriage_yaku_msg})
    # st.session_state.chat_log.append({"name": MORIAGE_YAKU2_NAME, "msg": moriage_yaku2_msg})