|
import streamlit as st |
|
import numpy as np |
|
import pandas as pd |
|
import sentencepiece |
|
from transformers import BertJapaneseTokenizer, BertModel |
|
from sentence_transformers import SentenceTransformer |
|
from sentence_transformers import models |
|
import torch |
|
from torch.nn.functional import cosine_similarity |
|
|
|
|
|
st.title("質問箱")


# Speaker identifiers used with st.chat_message and stored in the chat log.
USER_NAME = "user"
ASSISTANT_NAME = "assistant"
MORIAGE_YAKU_NAME = "moriage_yaku"
MORIAGE_YAKU2_NAME = "moriage_yaku2"

# Hugging Face model id of the Japanese BERT used for sentence embeddings.
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'


@st.cache_resource
def _load_bert(model_name):
    """Load the tokenizer and BERT encoder for ``model_name`` once per server process.

    Streamlit re-executes this whole script on every user interaction. Without
    ``st.cache_resource``, the tokenizer and the full model weights would be
    reloaded from disk on each rerun.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    loaded_tokenizer = BertJapaneseTokenizer.from_pretrained(model_name)
    loaded_model = BertModel.from_pretrained(model_name)
    return loaded_tokenizer, loaded_model


tokenizer, model = _load_bert(MODEL_NAME)
|
|
|
def sentence_to_vector(model, tokenizer, sentence):
    """Embed ``sentence`` as the mean of BERT's last-layer token vectors.

    Args:
        model: a ``transformers`` BERT model (anything exposing ``config`` and
            returning an object with ``last_hidden_state``).
        tokenizer: the tokenizer matching ``model``.
        sentence: text to embed.

    Returns:
        A 1-D tensor: the last hidden state averaged over every token the
        tokenizer produced (including any special tokens it adds). For
        compatibility, the same tensor is also stored in
        ``st.session_state.averaged_hidden_state``.
    """
    # BERT can only attend over ``max_position_embeddings`` positions (512 for
    # bert-base). Longer free-text chat input would otherwise make the forward
    # pass fail, so truncate it.
    max_len = model.config.max_position_embeddings
    token_ids = tokenizer(sentence, truncation=True, max_length=max_len)["input_ids"]

    # Shape (1, seq_len): a batch containing a single sentence.
    input_ids = torch.tensor(token_ids).reshape(1, -1)

    with torch.no_grad():
        outputs = model(input_ids)

    # Drop the batch dimension -> (seq_len, hidden_size).
    last_hidden_state = outputs.last_hidden_state[0]

    st.session_state.averaged_hidden_state = last_hidden_state.mean(dim=0)
    return st.session_state.averaged_hidden_state
|
|
|
def calc_similarity(sentence1, sentence2):
    """Return the cosine similarity of the BERT embeddings of two sentences.

    Uses the module-level ``model`` and ``tokenizer``. Side effects (kept for
    compatibility):
    - stores both sentence vectors and the score in ``st.session_state``;
    - writes the score to the page.

    Returns:
        The cosine similarity as a Python ``float``.
    """
    st.session_state.sentence_vector1 = sentence_to_vector(model, tokenizer, sentence1)
    st.session_state.sentence_vector2 = sentence_to_vector(model, tokenizer, sentence2)
    st.session_state.score = cosine_similarity(
        st.session_state.sentence_vector1,
        st.session_state.sentence_vector2,
        dim=0,
    ).detach().numpy().copy()
    st.write(st.session_state.score)
    # The caller stores this result in ``st.session_state.value``. Previously
    # nothing was returned, so that value was always None.
    return float(st.session_state.score)
|
|
|
# Per-session chat history: a list of {"name": speaker, "msg": text} dicts,
# created only the first time this session runs the script.
if "chat_log" not in st.session_state:
    st.session_state["chat_log"] = []

# Speaker name -> avatar passed to st.chat_message. Speakers missing from this
# dict are rendered with avatar=None (Streamlit's default).
avator_img_dict = {MORIAGE_YAKU_NAME: "🎉"}
|
|
|
user_msg = st.chat_input("質問、要望等あれば入力してください")

if user_msg:
    # Compare the user's text against a fixed reference sentence.
    st.session_state["sentence1"] = user_msg
    st.session_state["sentence2"] = "名前はまだない。"
    # Placeholders for the best dictionary match (not yet filled in here).
    st.session_state["similar_value"] = 0
    st.session_state["similar_word"] = ""
    # Reference dictionary loaded from the working directory.
    st.session_state["df"] = pd.read_csv('dict.csv')
    st.session_state["value"] = calc_similarity(
        st.session_state["sentence1"],
        st.session_state["sentence2"],
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Replay the stored conversation. Streamlit reruns the script from the top on
# every interaction, so the history has to be re-rendered each time.
for chat in st.session_state.chat_log:
    avatar = avator_img_dict.get(chat["name"], None)
    with st.chat_message(chat["name"], avatar=avatar):
        st.write(chat["msg"])


assistant_msg = "もう一度入力してください"  # NOTE(review): not referenced in this file
moriage_yaku_msg = "アンコール!アンコール!"
moriage_yaku2_msg = "そっれ、アンコール!アンコール!"  # NOTE(review): not referenced in this file

# Render and record a new exchange only when the user actually submitted text.
# st.chat_input returns None on reruns without a submission. Previously this
# ran unconditionally, displaying None and appending None messages to the
# history on every rerun (including the first page load).
if user_msg:
    with st.chat_message(USER_NAME):
        st.write(user_msg)

    # The assistant currently echoes the user's message back.
    with st.chat_message(ASSISTANT_NAME):
        st.write(user_msg)

    with st.chat_message(MORIAGE_YAKU_NAME, avatar=avator_img_dict[MORIAGE_YAKU_NAME]):
        st.write(moriage_yaku_msg)

    # Log exactly what was displayed, so the replayed history matches the live
    # view. The cheer-leader entry used to be logged with the user's text.
    st.session_state.chat_log.append({"name": USER_NAME, "msg": user_msg})
    st.session_state.chat_log.append({"name": ASSISTANT_NAME, "msg": user_msg})
    st.session_state.chat_log.append({"name": MORIAGE_YAKU_NAME, "msg": moriage_yaku_msg})
|
|
|
|