find_my_book / find.py
Shchushch's picture
Upload 6 files
7576ded
raw
history blame
3.38 kB
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModel
import numpy as np
tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
model = AutoModel.from_pretrained("cointegrated/rubert-tiny2")
def embed_bert_cls(text, model=model, tokenizer=tokenizer):
"""
Встраивает входной текст с использованием модели на основе BERT.
Аргументы:
text (str): Входной текст для встраивания.
model (torch.nn.Module): Модель на основе BERT для использования при встраивании.
tokenizer (transformers.PreTrainedTokenizer): Токенизатор для токенизации текста.
Возвращает:
numpy.ndarray: Встроенное представление входного текста.
"""
# Токенизируем текст и преобразуем его в PyTorch тензоры
t = tokenizer(text, padding=True, truncation=True, return_tensors='pt')
# Отключаем вычисление градиентов
with torch.no_grad():
# Пропускаем тензоры через модель
model_output = model(**{k: v.to(model.device) for k, v in t.items()})
# Извлекаем последний скрытый состояние из выходных данных модели
embeddings = model_output.last_hidden_state[:, 0, :]
# Нормализуем встроенные представления
embeddings = torch.nn.functional.normalize(embeddings)
embeddings=embeddings[0].cpu().numpy()
# Преобразуем встроенные представления в массив numpy и возвращаем первый элемент
return embeddings
df=pd.read_csv('books_sample.csv',index_col=0)
embs=[]
for annotation in df['annotation']:
# embd=
#print(embd)
embs.append(embed_bert_cls(annotation))
#embs.append(embed_bert_cls(annotation))
embs =np.array(embs)
def find_similar(text, embeddings=embs, threshold=0.5):
"""
Находит похожие тексты на основе косинусного сходства.
Аргументы:
text (str): Входной текст для поиска похожих текстов.
embeddings (numpy.ndarray): Предварительно вычисленные встроенные представления текстов.
threshold (float): Порог, выше которого тексты считаются похожими.
Возвращает:
numpy.ndarray: Сходства между входным текстом и каждым текстом во встроенных представлениях.
"""
# Встраиваем входной текст
embedding = embed_bert_cls(text)
# Вычисляем косинусное сходство между встроенным представлением входного текста и всеми встроенными представлениями
similarities = embeddings.dot(embedding)
sorted_indeces=similarities.argsort()[::-1]#[::1]
return similarities,sorted_indeces
print(find_similar('пук',embeddings=embs))