{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import torch\n", "from transformers import AutoTokenizer, AutoModel\n", "import re\n", "import string\n", "import numpy as np\n", "from sklearn.metrics.pairwise import cosine_similarity\n", "import streamlit as st\n", "import faiss\n", "\n" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "url = '/clean_mail_movie.csv'\n", "\n", "df = pd.read_csv(url)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "dataset = df['concat2embedding'].tolist() # Это объединённый столбец\n", "titles = df['movie_title'].tolist()\n", "images = df['image_url'].tolist()\n", "descr = df['description'].tolist()\n", "links = df['page_url'].tolist()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def clean(text):\n", " text = text.lower() # Нижний регистр\n", " # text = re.sub(r'\\d+', ' ', text) # Удаляем числа\n", " # text = text.translate(str.maketrans('', '', string.punctuation)) # Удаляем пунктуацию\n", " text = re.sub(r'\\s+', ' ', text) # Удаляем лишние пробелы\n", " text = text.strip() # Удаляем начальные и конечные пробелы\n", " # text = re.sub(r'\\b\\w{1,2}\\b', '', text) # Удаляем слова длиной менее 3 символов\n", " # Дополнительные шаги, которые могут быть полезны в данном контексте:\n", " # text = re.sub(r'\\b\\w+\\b', '', text) # Удаляем отдельные слова (без чисел и знаков препинания)\n", " # text = ' '.join([word for word in text.split() if word not in stop_words]) # Удаляем стоп-слова\n", " return text\n", "\n", "\n", "cleaned_text = []\n", "\n", "for text in dataset:\n", " cleaned_text.append(clean(text))" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# pip install transformers sentencepiece\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"cointegrated/rubert-tiny2\")\n", "model = AutoModel.from_pretrained(\"cointegrated/rubert-tiny2\")\n", "# model.cuda() # uncomment it if you have a GPU" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# Дефолтная функция, шла в комплекте с моделью\n", "\n", "def embed_bert_cls(text, model, tokenizer):\n", " t = tokenizer(text, padding=True, truncation=True, return_tensors='pt', max_length=1024) # Модель сама создаёт пэддинги и маску.\n", " with torch.no_grad():\n", " model_output = model(**{k: v.to(model.device) for k, v in t.items()})\n", " embeddings = model_output.last_hidden_state[:, 0, :]\n", " embeddings = torch.nn.functional.normalize(embeddings)\n", " return embeddings[0].cpu().numpy()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Векторизация отзывов\n", "text_embeddings = np.array([embed_bert_cls(text, model, tokenizer) for text in cleaned_text])" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Создание FAISS индекса после определения text_embeddings\n", "dimension = text_embeddings.shape[1]\n", "index = faiss.IndexFlatL2(dimension)\n", "index.add(text_embeddings.astype('float32'))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['mail_embeddings.joblib']" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from joblib import dump, load\n", "\n", "# Сохранение эмбеддингов\n", "dump(text_embeddings, 
'mail_embeddings.joblib')" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "# Сохранение индекса\n", "faiss.write_index(index, \"mail_faiss_index.index\")" ] } ], "metadata": { "kernelspec": { "display_name": "pytorch_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }