# chatbot / app.py (author: pyamath)
# Last commit: "Update app.py" (ef32437, verified)
import gradio as gr
import numpy as np
import pandas as pd
import sentencepiece
import transformers
transformers.BertTokenizer = transformers.BertJapaneseTokenizer
from sentence_transformers import SentenceTransformer
from sentence_transformers import models
from sentence_transformers import util
import torch
# Speaker labels. Currently not referenced anywhere else in this file.
USER_NAME = "user"
ASSISTANT_NAME = "assistant"

# Sentence encoder: the Japanese BERT (whole-word masking) checkpoint followed
# by mean pooling over its token embeddings (CLS and max pooling disabled), so
# every input text is mapped to a single vector of the transformer's width.
transformer = models.Transformer("cl-tohoku/bert-base-japanese-whole-word-masking")
pooling = models.Pooling(
    transformer.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False,
)
model = SentenceTransformer(modules=[transformer, pooling])

# FAQ table loaded once at import time; chat() reads its "question" and
# "answer" columns.
model_df = pd.read_csv("dict.csv")
# Lazily built (faq_rows, embeddings) pair; filled on the first chat() call.
_faq_cache = None


def _faq_index():
    """Return the de-duplicated FAQ rows and their sentence embeddings.

    Rows are de-duplicated on "question", keeping the first occurrence, so the
    i-th embedding corresponds to ``faq.iloc[i]``. The result is computed once
    and reused, because ``model_df`` is loaded once at import time and the
    corpus would otherwise be re-encoded on every request.

    Raises:
        ValueError: if the FAQ table contains no questions.
    """
    global _faq_cache
    if _faq_cache is None:
        faq = model_df.drop_duplicates(subset="question")
        if faq.empty:
            raise ValueError("dict.csv contains no questions to match against")
        embeddings = model.encode(faq["question"].tolist(), convert_to_tensor=True)
        # Two concurrent first calls may both compute this; that only
        # duplicates work, the stored value is the same.
        _faq_cache = (faq, embeddings)
    return _faq_cache


def chat(user_msg):
    """Answer *user_msg* with the FAQ entry whose question is most similar.

    Args:
        user_msg: Free-text query from the Gradio text box.

    Returns:
        ``(answer, score)``: the "answer" of the best-matching FAQ row and the
        cosine similarity between the query and that row's question.

    Raises:
        ValueError: if the FAQ table contains no questions.
    """
    faq, corpus_embeddings = _faq_index()
    query_embedding = model.encode(user_msg, convert_to_tensor=True)
    # One query against the whole corpus -> a single row of similarities.
    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)
    top_scores, top_indices = torch.topk(cos_scores, k=1)
    # The index is a *position* in the de-duplicated rows; using it as a label
    # into model_df would pick the wrong answer whenever duplicates were dropped.
    return faq["answer"].iloc[top_indices.item()], top_scores.item()
# One text box in; the matched answer and its similarity score out,
# matching the (answer, score) pair returned by chat().
iface = gr.Interface(fn=chat, inputs="text", outputs=["text", "number"])

# Only start the web server when run as a script (e.g. `python app.py`),
# so importing this module has no server side effect.
if __name__ == "__main__":
    iface.launch()