File size: 2,120 Bytes
721488a 076d69d 721488a 02c4857 167d6cd 02c4857 721488a c499fb4 721488a c499fb4 02c4857 167d6cd 20be688 3b40e0a 721488a 3a042e6 721488a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import gradio as gr
import numpy as np
import pandas as pd
import sentencepiece
# from sentence_transformers import SentenceTransformer
# from sentence_transformers import models
from transformers import BertJapaneseTokenizer, BertModel
import torch
# Presumably speaker labels for a chat transcript; not referenced anywhere in
# the visible code.
USER_NAME = "user"
ASSISTANT_NAME = "assistant"
# Model identifier passed to from_pretrained (a Hugging Face Hub id or local path).
MODEL_NAME = 'cl-tohoku/bert-base-japanese-whole-word-masking'
# Tokenizer and encoder are loaded once at import time; calc_similarity() reads
# these module-level globals on every call.
tokenizer = BertJapaneseTokenizer.from_pretrained(MODEL_NAME)
model = BertModel.from_pretrained(MODEL_NAME)
# Question/answer table; chat() reads its "question" and "answer" columns.
# The relative path is resolved against the current working directory.
model_df = pd.read_csv('dict.csv')
def sentence_to_vector(model, tokenizer, sentence):
    """Embed *sentence* as the mean of the model's last-layer token vectors.

    Args:
        model: A Hugging Face BERT-style encoder whose output exposes
            ``last_hidden_state`` and whose ``config`` has
            ``max_position_embeddings``.
        tokenizer: The matching tokenizer; called with ``truncation`` and
            ``max_length`` and expected to return a mapping with "input_ids".
        sentence: The text to embed.

    Returns:
        A 1-D tensor: the average over all token positions (including any
        special tokens the tokenizer adds) of the last hidden state.
    """
    # Split the sentence into vocabulary ids. Truncate to the model's
    # position-embedding limit so over-long user input cannot crash the
    # forward pass.
    token_ids = tokenizer(
        sentence,
        truncation=True,
        max_length=model.config.max_position_embeddings,
    )["input_ids"]
    # Batch of one: shape (1, seq_len).
    input_ids = torch.tensor(token_ids).reshape(1, -1)
    # Inference only: skip autograd bookkeeping. Only the last layer is used,
    # so the other hidden states are not requested.
    with torch.no_grad():
        outputs = model(input_ids)
    # Drop the batch axis and average over token positions.
    last_hidden_state = outputs.last_hidden_state[0]
    return last_hidden_state.mean(dim=0)
def cosine_similarity(x1, x2, eps):  # dim omitted for simplicity
    """Return the cosine similarity of *x1* and *x2* as a Python float.

    Every element takes part in one dot product (no per-dimension reduction).
    The product of squared norms is clamped below at ``eps**2`` before the
    square root, so two all-zero inputs yield 0.0 instead of dividing by zero.
    """
    dot = (x1 * x2).sum()
    norm_product = ((x1 * x1).sum() * (x2 * x2).sum()).clamp_min(eps * eps).sqrt()
    return (dot / norm_product).item()
def calc_similarity(sentence1, sentence2):
    """Return the cosine similarity between the embeddings of two sentences.

    Both sentences are embedded with the module-level ``model`` and
    ``tokenizer`` (sentence1 first), then compared with eps=1e-8.
    """
    vec_a, vec_b = (
        sentence_to_vector(model, tokenizer, text) for text in (sentence1, sentence2)
    )
    return cosine_similarity(vec_a, vec_b, 1e-8)
def chat(user_msg):
    """Answer *user_msg* with the stored answer whose question is most similar.

    Every row of ``model_df`` is scored by the cosine similarity between the
    embedding of *user_msg* and the embedding of that row's "question".

    Returns:
        ``(answer, score)`` for the highest-scoring row. Only strictly positive
        scores can win, so ``("", 0)`` is returned when no question scores
        above 0 (or the table is empty).
    """
    # The user's embedding does not change across rows, so compute it once.
    user_vector = sentence_to_vector(model, tokenizer, user_msg)
    best_answer = ""
    best_score = 0
    # Iterate over the actual rows instead of a hard-coded count, which
    # raised KeyError on shorter tables and ignored rows in longer ones.
    for question, answer in zip(model_df["question"], model_df["answer"]):
        question_vector = sentence_to_vector(model, tokenizer, question)
        score = cosine_similarity(user_vector, question_vector, 1e-8)
        if score > best_score:
            best_score = score
            best_answer = answer
    return best_answer, best_score
# Web UI: one text box in; the matched answer (text) and its similarity
# score (number) out, as returned by chat().
iface = gr.Interface(fn=chat, inputs="text", outputs=["text", "number"])

# Start the server only when run as a script, so importing this module
# (e.g. to reuse chat()) does not launch the web app as a side effect.
if __name__ == "__main__":
    iface.launch()