|
import gradio as gr |
|
import numpy as np |
|
import pandas as pd |
|
import sentencepiece |
|
import transformers |
|
transformers.BertTokenizer = transformers.BertJapaneseTokenizer |
|
from sentence_transformers import SentenceTransformer |
|
from sentence_transformers import models |
|
from sentence_transformers import util |
|
import torch |
|
|
|
# Speaker role labels. NOTE(review): unused by the visible code — possibly
# left over from an earlier chat-history feature.
USER_NAME = "user"
ASSISTANT_NAME = "assistant"

# Japanese sentence-embedding model: Tohoku-BERT encoder followed by
# mean pooling over the token embeddings (CLS/max pooling disabled).
transformer = models.Transformer('cl-tohoku/bert-base-japanese-whole-word-masking')
pooling = models.Pooling(
    transformer.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False,
)
model = SentenceTransformer(modules=[transformer, pooling])

# Question/answer lookup table; `chat` reads its "question" and "answer" columns.
model_df = pd.read_csv('dict.csv')
|
|
|
def chat(user_msg):
    """Return the dictionary answer whose question best matches *user_msg*.

    Encodes the user message with the sentence-embedding model, compares it
    against the embeddings of every (deduplicated) question in ``model_df``,
    and returns the answer paired with the best question.

    Parameters:
        user_msg: free-form query text from the Gradio text input.

    Returns:
        (answer, score): the "answer" cell for the best-matching question,
        and its cosine-similarity score as a float.
    """
    # Deduplicate whole rows on the question column so each question stays
    # aligned with its own answer.  (The original code deduplicated only the
    # question Series, so the top-k index into that shortened list no longer
    # mapped to the right position in the full "answer" column.)
    qa = model_df.drop_duplicates(subset="question").reset_index(drop=True)
    questions = qa["question"].tolist()

    # Corpus embeddings are invariant between calls; cache them on the
    # function instead of re-encoding the whole corpus per message.
    cached = getattr(chat, "_corpus_cache", None)
    if cached is None or cached[0] != questions:
        corpus_embeddings = model.encode(questions, convert_to_tensor=True)
        chat._corpus_cache = (questions, corpus_embeddings)
    else:
        corpus_embeddings = cached[1]

    query_embedding = model.encode(user_msg, convert_to_tensor=True)
    # cos_scores has shape (1, len(questions)); take the single best match.
    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)
    top_values, top_indices = torch.topk(cos_scores, k=1)
    best = top_indices[0][0].item()
    return qa["answer"].iloc[best], top_values[0][0].item()
|
|
|
|
|
|
|
# Expose the matcher as a minimal web UI: one text input, and two outputs —
# the matched answer (text) and its cosine-similarity score (number).
iface = gr.Interface(
    fn=chat,
    inputs="text",
    outputs=["text", "number"],
)
iface.launch()