|
import torch |
|
import pandas as pd |
|
from transformers import AutoTokenizer, AutoModel,BertTokenizer,BertModel |
|
import numpy as np |
|
import pickle |
|
|
|
import nltk |
|
nltk.download('stopwords') |
|
nltk.download('averaged_perceptron_tagger') |
|
nltk.download('wordnet') |
|
from nltk.stem import WordNetLemmatizer |
|
from nltk.tag import pos_tag |
|
from nltk.corpus import stopwords |
|
from pymystem3 import Mystem |
|
from functools import lru_cache |
|
import string |
|
import faiss |
|
from tqdm import tqdm |
|
DEVICE='cpu' |
|
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2") |
|
model = AutoModel.from_pretrained("cointegrated/rubert-tiny2") |
|
eng_stop_words = stopwords.words('english') |
|
with open('assets/russian.txt', 'r') as f: |
|
ru_stop_words = f.read() |
|
ru_stop_words=ru_stop_words.split('\n') |
|
allow="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyzАБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдеёжзийклмнопрстуфхцчшщъыьэюя0123456789-' \n\t" |
|
|
|
m= Mystem() |
|
def embed_bert_cls(text, model=model, tokenizer=tokenizer)->np.array: |
|
""" |
|
Встраивает входной текст с использованием модели на основе BERT. |
|
|
|
Аргументы: |
|
text (str): Входной текст для встраивания. |
|
model (torch.nn.Module): Модель на основе BERT для использования при встраивании. |
|
tokenizer (transformers.PreTrainedTokenizer): Токенизатор для токенизации текста. |
|
|
|
Возвращает: |
|
numpy.ndarray: Встроенное представление входного текста. |
|
""" |
|
|
|
t = tokenizer(text, padding=True, truncation=True, return_tensors='pt') |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
model_output = model(**{k: v.to(DEVICE) for k, v in t.items()}) |
|
|
|
|
|
embeddings = model_output.last_hidden_state[:, 0, :] |
|
|
|
|
|
embeddings = torch.nn.functional.normalize(embeddings) |
|
embeddings=embeddings[0].cpu().numpy() |
|
|
|
|
|
return embeddings |
|
|
|
def lems_eng(text): |
|
if type(text)==type('text'): |
|
text=text.split() |
|
wnl= WordNetLemmatizer() |
|
lemmatized= [] |
|
pos_map = { |
|
'NN': 'n', |
|
'NNS': 'n', |
|
'NNP': 'n', |
|
'NNPS': 'n', |
|
'VB': 'v', |
|
'VBD': 'v', |
|
'VBG': 'v', |
|
'VBN': 'v', |
|
'JJ': 'a', |
|
'JJR': 'a', |
|
'JJS': 'a', |
|
'RB': 'r', |
|
'RBR': 'r', |
|
'RBS': 'r', |
|
'PRP': 'n', |
|
'PRP$': 'n', |
|
'DT': 'n' |
|
} |
|
pos_tags = pos_tag(text) |
|
lemmas = [] |
|
for token, pos in pos_tags: |
|
pos = pos_map.get(pos,'n') |
|
lemma = wnl.lemmatize(token, pos=pos) |
|
lemmas.append(lemma) |
|
return ' '.join(lemmas) |
|
|
|
def lems_rus(texts): |
|
if type(texts)==type([]): |
|
texts=' '.join(texts) |
|
|
|
lemmas = m.lemmatize(texts) |
|
return ''.join(lemmas) |
|
def clean(text: str)-> str: |
|
|
|
|
|
text = ''.join(c for c in text if c in allow) |
|
text= text.split() |
|
text = [word for word in text if word.lower() not in ru_stop_words] |
|
text = [word for word in text if word.lower() not in eng_stop_words] |
|
return ' '.join(text) |
|
|
|
|
|
def improved_lemmatizer(texts,batch_size=1000): |
|
if type(texts)==type('text'): |
|
texts=texts.split() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
df=pd.read_csv('assets/final_and_lem.csv',index_col=0).reset_index(drop=True) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with open('assets/embs.pickle', 'rb') as f: |
|
embs = pickle.load(f) |
|
|
|
embs =np.array(embs) |
|
print('Тип выхода:',type(embs),'Размер выхода: ',embs.shape) |
|
|
|
|
|
|
|
index=faiss.IndexFlatL2(embs.shape[1]) |
|
index.add(embs) |
|
@lru_cache() |
|
def find_similar(text, k=10): |
|
""" |
|
Находит похожие тексты на основе косинусного сходства. |
|
|
|
Аргументы: |
|
text (str): Входной текст для поиска похожих текстов. |
|
embeddings (numpy.ndarray): Предварительно вычисленные встроенные представления текстов. |
|
threshold (float): Порог, выше которого тексты считаются похожими. |
|
|
|
Возвращает: |
|
numpy.ndarray: Сходства между входным текстом и каждым текстом во встроенных представлениях. |
|
""" |
|
|
|
|
|
text_emb = embed_bert_cls(text) |
|
print('Текстовые эмбединги\t',text_emb ) |
|
text_emb = np.expand_dims(text_emb, axis=0) |
|
print(f'Тип поискового запроса: {type(text_emb)}\nРазмер полученного запроса: {text_emb.shape}') |
|
dist,idx=index.search(text_emb,k) |
|
print(f'Расстнояния:{dist}\tАйдишки{idx}') |
|
return dist.squeeze()[::-1],idx.squeeze()[::-1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|