import gradio as gr
import pandas as pd
import sentencepiece  # not used directly; imported so the dependency is available
import transformers

# Make BERT checkpoints load the Japanese tokenizer (MeCab-based word
# segmentation) instead of the default BertTokenizer.
transformers.BertTokenizer = transformers.BertJapaneseTokenizer

from sentence_transformers import SentenceTransformer, models, util
import torch

USER_NAME = "user"            # currently unused
ASSISTANT_NAME = "assistant"  # currently unused

# Sentence-embedding model: Japanese BERT encoder followed by mean pooling.
transformer = models.Transformer('cl-tohoku/bert-base-japanese-whole-word-masking')
pooling = models.Pooling(
    transformer.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True,
    pooling_mode_cls_token=False,
    pooling_mode_max_tokens=False,
)
model = SentenceTransformer(modules=[transformer, pooling])

# Q&A dictionary with "question" and "answer" columns. Drop duplicate
# questions and reset the index so that the positions returned by torch.topk
# line up with the dataframe rows.
model_df = pd.read_csv('dict.csv').drop_duplicates(subset="question").reset_index(drop=True)

# Encode the question corpus once at startup rather than on every request.
corpus_embeddings = model.encode(model_df["question"].tolist(), convert_to_tensor=True)

def chat(user_msg):
    # Embed the user message, find the most similar stored question, and
    # return its answer together with the cosine-similarity score.
    query_embedding = model.encode(user_msg, convert_to_tensor=True)
    cos_scores = util.cos_sim(query_embedding, corpus_embeddings)
    top_results = torch.topk(cos_scores, k=1)
    best_score = top_results[0][0].item()
    best_idx = top_results[1][0].item()
    return model_df["answer"][best_idx], best_score

# Two outputs: the matched answer text and its similarity score.
iface = gr.Interface(fn=chat, inputs="text", outputs=["text", "number"])
iface.launch()
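
# For reference, a minimal sketch of the layout dict.csv is assumed to have,
# inferred from the columns accessed above ("question" and "answer"); the
# actual file contents are not part of this snapshot and the rows below are
# purely illustrative placeholders:
#
#   question,answer
#   <Japanese question text>,<Japanese answer text>
#   <Japanese question text>,<Japanese answer text>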