movies_search / EmbeddingGenerator.py
AlexandraGulamova's picture
Initial commit with all project files
e8d59a6
#EmbeddingGenerator.py
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="transformers.models.bert")
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class EmbeddingGenerator:
def __init__(self, pavlov_model_name="DeepPavlov/rubert-base-cased", sentence_transformer_model_name="cointegrated/rubert-tiny2"):
"""
Инициализирует токенизатор и модели для генерации эмбеддингов.
Args:
pavlov_model_name (str): Название модели для загрузки Pavlov модели.
sentence_transformer_model_name (str): Название модели SentenceTransformer для генерации эмбеддингов.
"""
self.pavlov_tokenizer = AutoTokenizer.from_pretrained(pavlov_model_name, ignore_mismatched_sizes=True)
self.pavlov_model = AutoModel.from_pretrained(pavlov_model_name, ignore_mismatched_sizes=True)
self.sentence_transformer_model = SentenceTransformer(sentence_transformer_model_name)
def generate_embeddings(self, texts, method="pavlov"):
"""
Генерирует эмбеддинги для списка текстов с использованием выбранного метода.
Args:
texts (list of str): Список текстов для генерации эмбеддингов.
method (str): Метод генерации эмбеддингов: "pavlov" или "rubert_tiny2".
Returns:
np.ndarray: Эмбеддинги текстов.
"""
if method == "pavlov":
# Генерация эмбеддингов с использованием Pavlov модели
inputs = self.pavlov_tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
with torch.no_grad():
outputs = self.pavlov_model(**inputs)
# Mean pooling
embeddings = outputs.last_hidden_state.mean(dim=1).cpu().numpy()
elif method == "rubert_tiny2":
# Генерация эмбеддингов с использованием SentenceTransformer
embeddings = self.sentence_transformer_model.encode(texts, show_progress_bar=False)
else:
raise ValueError("Unsupported method. Choose 'pavlov' or 'rubert_tiny2'.")
return embeddings