# warleagle's picture
# Update app.py
# 588b6a9 verified
#%%
import pandas as pd
import numpy as np
import torch
import json
import re
from sentence_transformers.util import cos_sim
from sentence_transformers import SentenceTransformer
from sklearn.feature_extraction.text import CountVectorizer
import gradio as gr
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
#%%
# Sentence-embedding model used to encode the n-grams extracted from user input.
model = SentenceTransformer('sentence-transformers/multi-qa-distilbert-cos-v1')

# NLTK's Russian stop-word list plus one extra token.
russian_stopwords = stopwords.words('russian') + ['ВАШ']

# Read the JSON explicitly as UTF-8 so non-ASCII text decodes the same way
# regardless of the platform's default locale encoding.
# NOTE(review): presumably an ordered collection whose positions line up with
# the rows of `embs` (it is indexed by the arg-max column) — confirm.
with open("top_150_symps_by_spec.json", 'r', encoding='utf-8') as f:
    symps = json.load(f)

# Precomputed embedding matrix that input n-gram embeddings are compared to.
with open("embeddings.npy", 'rb') as f:
    embs = np.load(f)
def remove_numbers(text):
    """Strip digits and punctuation from ``text`` and trim outer whitespace.

    Digit runs are removed first, then every character that is neither a
    word character nor whitespace. Inner whitespace is left as is, so gaps
    where characters were removed are not collapsed.
    """
    for pattern in (r'\d+', r'[^\w\s]'):
        text = re.sub(pattern, '', text)
    return text.strip()
# Tokenizer used to split input text into word 1-, 2- and 3-grams.
# NOTE(review): supplying a custom preprocessor replaces scikit-learn's
# default preprocessing stage (including lowercasing), so stop-word matching
# is presumably case-sensitive — confirm this is intended.
vectorizer = CountVectorizer(
    preprocessor=remove_numbers,
    stop_words=russian_stopwords,
    ngram_range=(1, 3),
)
def get_symptomps_v2(text, treshold = 0.7):
    """Return the known symptoms that best match the n-grams of ``text``.

    The input is split into n-grams by the module-level ``vectorizer``; each
    n-gram is embedded with ``model`` and compared by cosine similarity with
    the precomputed ``embs`` matrix. For every n-gram only its single most
    similar column is considered, and it is accepted when that similarity is
    positive and strictly greater than ``treshold``.

    Args:
        text: A string, or an iterable of strings, to analyse.
        treshold: Exclusive lower bound on the cosine similarity of a match.

    Returns:
        A sorted list of the distinct matched entries of ``symps``, or
        ``['Симптомы не определены']`` when nothing matches or when
        processing fails for any reason (the failure is logged).
    """
    import logging  # function-scope import: the file does not import logging

    not_found = ['Симптомы не определены']
    try:
        if isinstance(text, str):
            text = [text]
        # Fit only to obtain this input's n-gram vocabulary; the count matrix
        # itself is not needed. If fitting fails (e.g. nothing is left after
        # stop-word removal), the handler below returns the fallback.
        vectorizer.fit(text)
        ngrams = vectorizer.get_feature_names_out()
        text_emb = model.encode(ngrams, batch_size=64)
        cos_sim_m = cos_sim(text_emb, embs).numpy()

        # Best-matching column per n-gram and its similarity score.
        best_idx = cos_sim_m.argmax(axis=1)
        best_score = cos_sim_m.max(axis=1)
        # Keep a match only if it beats the threshold and is positive.
        accepted = (best_score > treshold) & (best_score > 0)
        outputs = [symps[idx] for idx in best_idx[accepted]]
        if not outputs:
            return not_found
        return np.unique(outputs).tolist()
    except Exception:
        # Best-effort UI handler: never surface a traceback to the user, but
        # do record it. `Exception` (not a bare except) lets KeyboardInterrupt
        # and SystemExit propagate.
        logging.getLogger(__name__).exception("Symptom extraction failed")
        return not_found
#%%
# Slider whose value is passed as the second argument of the handler.
relevance_slider = gr.Slider(
    label="Порог релевантности",
    minimum=0,
    maximum=1,
    step=0.05,
    value=0.8,
)

# Web UI: a text box and the slider in, the handler's list out as JSON.
gradio_app = gr.Interface(
    fn=get_symptomps_v2,
    inputs=['text', relevance_slider],
    outputs=[gr.JSON(label='Симптомы: ')],
    description="Введите услугу:",
)

# Start the server only when the file is run as a script.
if __name__ == "__main__":
    gradio_app.launch()
# %%