Spaces:
Sleeping
Sleeping
File size: 1,659 Bytes
46b2548 3fecf82 46b2548 043145c 46b2548 38201fe 46b2548 621ac26 46b2548 86a77c5 043145c 46b2548 38201fe 46b2548 2afb07e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
import json
import numpy as np
import os
from transformers import BertTokenizer
from rank_bm25 import BM25Okapi
import gradio as gr
# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv('HF_TOKEN')
# Flagging callback that writes flagged samples to the "budu_search_data" dataset.
hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "budu_search_data")
tokenizer = BertTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")

# Load the search database: a JSON object whose keys are the names returned to
# the user and whose values are the documents indexed by BM25.
# NOTE(review): values are presumably pre-tokenized lists produced with the
# same tokenizer as above -- confirm against the script that builds the file.
# The context manager closes the handle once parsing is done, and the explicit
# encoding keeps non-ASCII text readable regardless of the platform default.
with open('budu_search_syn_database.json', encoding='utf-8') as f:
    database = json.load(f)

# Parallel lists: b25corpus[i] is the document for b25local_names[i].
b25corpus = list(database.values())
b25local_names = list(database)
bm25 = BM25Okapi(corpus=b25corpus)
def predict_bm25(service):
    """Return up to three database names that best match the query.

    The query is lower-cased, tokenized with the module-level ``tokenizer``
    and scored against every document with the module-level ``bm25`` index.

    Args:
        service: Free-text user query.

    Returns:
        A list of at most three names from ``b25local_names`` whose BM25
        score is strictly greater than 1, ordered from the highest score
        to the lowest. The list is empty when no document passes the
        threshold.
    """
    query_tokens = tokenizer.tokenize(service.lower())
    relevance = bm25.get_scores(query_tokens)

    # Document indices ordered from the highest score to the lowest.
    ranking = relevance.argsort()[::-1]
    ranked_names = np.array([b25local_names[idx] for idx in ranking])
    ranked_scores = relevance[ranking]

    # Keep only sufficiently relevant documents; the mask preserves ranking order.
    matches = ranked_names[ranked_scores > 1].tolist()

    # Slicing is a no-op for three or fewer matches.
    return matches[:3]
# Web UI: one text box for the user query, one text box for the recommended
# services. Flagging is set to 'auto' and routed through hf_writer.
demo = gr.Interface(
    fn=predict_bm25,
    inputs=gr.components.Textbox(label='Запрос пользователя'),
    outputs=[gr.components.Textbox(label='Рекомендованные услуги')],
    allow_flagging='auto',
    flagging_callback=hf_writer,
    # Sample queries offered in the UI.
    examples=[
        ['кальций'],
        ['узи'],
        ['железо'],
        ['прием'],
    ],
)

# Start the server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()