File size: 1,659 Bytes
46b2548
 
3fecf82
46b2548
 
 
 
 
043145c
 
 
46b2548
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38201fe
46b2548
621ac26
 
46b2548
 
 
 
86a77c5
043145c
46b2548
 
38201fe
46b2548
 
 
 
2afb07e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import json
import numpy as np
import os

from transformers import BertTokenizer
from rank_bm25 import BM25Okapi
import gradio as gr

# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv('HF_TOKEN')
# Gradio flagging callback backed by the "budu_search_data" HF dataset.
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "budu_search_data")

# Tokenizer applied to user queries before BM25 scoring.
tokenizer = BertTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")

# Load the service database. `with` guarantees the file handle is closed, and
# an explicit UTF-8 encoding keeps the (Cyrillic) content decoding correctly
# regardless of the platform's default locale encoding.
with open('budu_search_syn_database.json', encoding='utf-8') as f:
    database = json.load(f)

# Values become the BM25 documents and keys the names returned to the user;
# both lists share the dict's iteration order, so index i refers to the same
# entry in each. BM25Okapi expects every document as a list of tokens, so the
# values are assumed to be pre-tokenized lists — TODO confirm against the file.
b25corpus = list(database.values())
b25local_names = list(database.keys())
bm25 = BM25Okapi(corpus=b25corpus)

def predict_bm25(service, *, top_k=3, min_score=1.0):
    """Return the best-matching service names for a free-text query.

    The query is lower-cased, tokenized with the module-level ``tokenizer``
    and scored against every document of the module-level ``bm25`` index.

    Args:
        service: User query text. ``None`` is treated as an empty query.
        top_k: Maximum number of names to return (default 3).
        min_score: Only documents whose BM25 score is strictly greater than
            this value are returned (default 1.0).

    Returns:
        A list of at most ``top_k`` names from ``b25local_names``, ordered by
        descending score. Documents with equal scores keep their database
        order. The list is empty when nothing clears ``min_score``.
    """
    tokenized_query = tokenizer.tokenize((service or "").lower())
    doc_scores = bm25.get_scores(tokenized_query)

    # Stable sort on negated scores: best first, ties stay in corpus order.
    # (argsort()[::-1] used a non-stable sort and reversed tied entries.)
    order = np.argsort(-doc_scores, kind="stable")

    matches = [b25local_names[i] for i in order if doc_scores[i] > min_score]
    return matches[:max(top_k, 0)]

# Sample queries offered beneath the input box; Gradio expects one
# single-element row per example because the interface has one input.
example_queries = ['кальций', 'узи', 'железо', 'прием']

demo = gr.Interface(
    fn=predict_bm25,
    inputs=gr.components.Textbox(label='Запрос пользователя'),
    outputs=[gr.components.Textbox(label='Рекомендованные услуги')],
    # Submissions are flagged automatically and handed to the dataset saver.
    allow_flagging='auto',
    flagging_callback=hf_writer,
    examples=[[query] for query in example_queries],
)

# Start the web UI only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()