nn_ext / resources /functions.py
ElijahDi's picture
Upload 7 files
cb99840 verified
raw
history blame
No virus
1.52 kB
import re
import string
import pandas as pd
# import numpy as np
# import torch
import nltk
import pymorphy2
from nltk.corpus import stopwords
# Runs on every import of this module: asks NLTK to fetch the 'stopwords'
# corpus before it is read a few lines below.
nltk.download('stopwords')
from sentence_transformers import SentenceTransformer, util
# Russian stop words, held in a set for membership tests in
# data_preprocessing_hard().
stop_words = set(stopwords.words('russian'))
# Shared morphological analyzer; data_preprocessing_hard() uses it to
# lemmatize tokens.
morph = pymorphy2.MorphAnalyzer()
# Sentence-embedding model that recommend() uses to encode the query text.
# Instantiated once at import time and reused by every call.
model = SentenceTransformer('cointegrated/rubert-tiny2')
def data_preprocessing_hard(text: str) -> str:
    """Normalize Russian text into a space-separated string of lemmas.

    Steps, in order: cast to ``str`` and lowercase; replace HTML-like tags
    with a space; delete every character that is not a lowercase Cyrillic
    letter (including "ё") or whitespace; drop tokens found in the
    module-level ``stop_words`` set; replace each remaining token with the
    ``normal_form`` of its first ``morph.parse`` result.

    Args:
        text: Raw input. Any value is accepted because it is passed through
            ``str()`` first.

    Returns:
        The lemmas joined by single spaces, or an empty string when no token
        survives the filtering.
    """
    text = str(text).lower()
    # Tags become a space rather than '' so the words on either side of a
    # tag are not fused together ("кот<br>пёс" must not become "котпёс").
    text = re.sub(r'<.*?>', ' ', text)
    # "ё" (U+0451) lies outside the а-я code point range, so it has to be
    # listed explicitly or it is silently stripped from every word. The text
    # is already lowercase, so no uppercase range is needed.
    text = re.sub(r'[^а-яё\s]', '', text)
    # No separate ASCII-punctuation pass: the substitution above has already
    # removed every non-Cyrillic, non-whitespace character.
    # Stop words are matched on the surface form, i.e. before lemmatization.
    tokens = [word for word in text.split() if word not in stop_words]
    return ' '.join(morph.parse(word)[0].normal_form for word in tokens)
def filter(df: pd.DataFrame, ganre_list: list):
    """Return the index labels of rows that share a genre with ``ganre_list``.

    A row matches when at least one element of its ``'ganres'`` cell is
    contained in ``ganre_list``. Rows whose cell is not iterable (for example
    ``None`` or ``NaN`` for a missing value) are treated as non-matching
    instead of raising ``TypeError`` and aborting the whole filter.

    Note: the name shadows the builtin ``filter``; it is kept because callers
    depend on it.

    Args:
        df: Frame with a ``'ganres'`` column; each cell is iterated element
            by element.
        ganre_list: Genres to look for.

    Returns:
        A list of the matching rows' index labels, in frame order.
    """
    def _has_wanted_genre(genres) -> bool:
        # Only the iter() call is guarded, so a TypeError raised by the
        # membership test itself is not swallowed.
        try:
            items = iter(genres)
        except TypeError:
            return False
        return any(g in ganre_list for g in items)

    # astype(bool) guarantees a boolean mask even for an empty frame, where
    # apply() may return a non-bool dtype.
    mask = df['ganres'].apply(_has_wanted_genre).astype(bool)
    return df[mask].index.to_list()
def recommend(text: str, embeddings, top_k):
    """Find the corpus entries most similar to a free-text query.

    The query is cleaned with ``data_preprocessing_hard``, encoded with the
    module-level ``model``, and compared against ``embeddings`` using
    ``util.semantic_search`` with ``util.dot_score``. Both sides are moved to
    the CPU and passed through ``util.normalize_embeddings`` first.

    Args:
        text: Raw query text.
        embeddings: Precomputed corpus embeddings. Must support ``.to("cpu")``
            (presumably a torch tensor with one row per corpus item; confirm
            against the caller).
        top_k: Maximum number of hits to return for the query.

    Returns:
        Whatever ``util.semantic_search`` returns for the single query. Per
        the sentence-transformers documentation this is a list with one entry
        per query, each a list of ``{'corpus_id', 'score'}`` dicts sorted by
        decreasing score.
    """
    query = data_preprocessing_hard(text)
    query_embeddings = model.encode([query], convert_to_tensor=True)

    corpus_embeddings = util.normalize_embeddings(embeddings.to("cpu"))
    query_embeddings = util.normalize_embeddings(query_embeddings.to("cpu"))

    # top_k must be passed by keyword: the third positional parameter of
    # semantic_search is query_chunk_size, so a positional top_k is silently
    # ignored and the library default of 10 hits is used instead.
    return util.semantic_search(
        query_embeddings,
        corpus_embeddings,
        top_k=top_k,
        score_function=util.dot_score,
    )