{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "48d89cbc-3660-49b9-be37-087c1b05bf78", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a776258b88794927b45bbfc8fc9bdcb0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Encoding images to latents: 0%| | 0/60000 [00:00" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "🔹 Текстовое описание: 4\n", "ok\n" ] } ], "source": [ "from datasets import load_from_disk\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "from PIL import Image \n", "\n", "dtype = torch.float16\n", "\n", "# Загружаем сохраненный датасет\n", "loaded_dataset = load_from_disk(save_path)\n", "\n", "# Проверяем структуру датасета\n", "#print(\"Структура датасета:\", loaded_dataset.features)\n", "\n", "# Выбираем пример для демонстрации\n", "example = loaded_dataset[2]\n", "\n", "# Выводим информацию о примере\n", "print(\"Форма латентного представления:\", np.array(example[\"vae\"]).shape)\n", "print(\"embedding shape:\", np.array(example[\"embeddings\"]).shape)\n", "\n", "# Преобразуем латентное представление в тензор PyTorch\n", "latent_tensor = torch.tensor(example[\"vae\"], dtype=dtype).unsqueeze(0).to(device)\n", "\n", "# Декодируем латентное представление обратно в изображение\n", "with torch.no_grad():\n", " #reconstructed_image = vae.decode(latent_tensor).sample # Результат — тензор\n", " latent = (latent_tensor.detach() / vae.config.scaling_factor) + vae.config.shift_factor\n", " reconstructed_image = vae.decode(latent).sample\n", "\n", "# Переносим тензор на CPU и преобразуем в NumPy массив\n", "\n", "reconstructed_image = reconstructed_image.squeeze(0).cpu().numpy() # Удаляем размерность батча\n", "\n", "# Переносим каналы в правильный формат (CHW -> HWC) и нормализуем значения пикселей\n", "reconstructed_image = np.transpose(reconstructed_image, (1, 2, 0))\n", "reconstructed_image = (reconstructed_image + 1) / 2 # Нормализация в диапазон [0, 1]\n", "\n", "# Преобразуем тип данных к float32\n", "reconstructed_image = reconstructed_image.astype(np.float32)\n", "reconstructed_image = np.clip(reconstructed_image, 0.0, 1.0)\n", "\n", "# Отображаем восстановленное изображение\n", "plt.figure(figsize=(4, 4))\n", "plt.imshow(reconstructed_image)\n", "plt.title(f\"Reconstructed Image\")\n", "plt.axis(\"off\")\n", "plt.show()\n", "print(f\"🔹 Текстовое описание: {example['text']}\")\n", "print(\"ok\")" ] }, { "cell_type": "code", "execution_count": 5, "id": "3818a9a7-f72c-42ea-9805-43bfd4b214a0", "metadata": {}, "outputs": [], "source": [ "#!pip install matplotlib" ] }, { "cell_type": "code", "execution_count": 2, "id": "f5336198-8925-4e03-ad81-72ee0eb5248a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "HF_HOME is set to: cache\n", "HF_DATASETS_CACHE is set to: cache/datasets\n" ] } ], "source": [ "# Проверьте переменные окружения\n", "hf_home = os.environ.get(\"HF_HOME\")\n", "hf_datasets_cache = os.environ.get(\"HF_DATASETS_CACHE\")\n", "\n", "if hf_home:\n", " print(f\"HF_HOME is set to: {hf_home}\")\n", "if hf_datasets_cache:\n", " print(f\"HF_DATASETS_CACHE is set to: {hf_datasets_cache}\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "94d7f0fe-3b27-4d08-ba35-ba7de096c635", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cache/datasets\n" ] } ], "source": [ "print(hf_datasets_cache)" ] }, { "cell_type": "code", "execution_count": 2, "id": "8206ec60-e828-4a56-b902-3b44b99b563c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting transformers\n", " Downloading transformers-4.49.0-py3-none-any.whl.metadata (44 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from transformers) (3.13.1)\n", "Requirement already satisfied: huggingface-hub<1.0,>=0.26.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.28.1)\n", "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (1.26.3)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from transformers) (24.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (6.0.2)\n", "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers) (2024.11.6)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from transformers) (2.32.3)\n", "Collecting tokenizers<0.22,>=0.21 (from transformers)\n", " Downloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)\n", "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.11/dist-packages (from transformers) (0.5.2)\n", "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.11/dist-packages (from transformers) (4.67.1)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (2024.2.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub<1.0,>=0.26.0->transformers) (4.12.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2.2.3)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->transformers) (2024.8.30)\n", "Downloading transformers-4.49.0-py3-none-any.whl (10.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.0/10.0 MB\u001b[0m \u001b[31m98.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hDownloading tokenizers-0.21.0-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.0/3.0 MB\u001b[0m \u001b[31m94.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: tokenizers, transformers\n", "Successfully installed tokenizers-0.21.0 transformers-4.49.0\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install -U transformers --break-system-packages" ] }, { "cell_type": "code", "execution_count": 4, "id": "48e6e1d3-310c-4f67-a27b-90a54036dc4d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting flash-attn\n", " Downloading flash_attn-2.7.4.post1.tar.gz (6.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.0/6.0 MB\u001b[0m \u001b[31m94.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25ldone\n", "\u001b[?25hRequirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (from flash-attn) (2.4.1+cu124)\n", "Collecting einops (from flash-attn)\n", " Downloading einops-0.8.1-py3-none-any.whl.metadata (13 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.13.1)\n", "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (4.12.2)\n", "Requirement already satisfied: sympy in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (1.12)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.2.1)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.1.3)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (2024.2.0)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.2.65 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.2.65)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.0.44 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (11.2.0.44)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.119 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (10.3.5.119)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.0.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (11.6.0.99)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.0.142 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.3.0.142)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.20.5 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (2.20.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.99 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (12.4.99)\n", "Requirement already satisfied: triton==3.0.0 in /usr/local/lib/python3.11/dist-packages (from torch->flash-attn) (3.0.0)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch->flash-attn) (2.1.5)\n", "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.11/dist-packages (from sympy->torch->flash-attn) (1.3.0)\n", "Downloading einops-0.8.1-py3-none-any.whl (64 kB)\n", "Building wheels for collected packages: flash-attn\n", " Building wheel for flash-attn (setup.py) ... \u001b[?25ldone\n", "\u001b[?25h Created wheel for flash-attn: filename=flash_attn-2.7.4.post1-cp311-cp311-linux_x86_64.whl size=187805408 sha256=92cf49e6f66795b6934cec0cba526ed6e45d3313de3f905d45df8773f19092a9\n", " Stored in directory: /root/.cache/pip/wheels/3d/88/d8/284b89f56af7d5bf366b10d6b8e251ac8a7c7bf3f04203fb4f\n", "Successfully built flash-attn\n", "Installing collected packages: einops, flash-attn\n", "Successfully installed einops-0.8.1 flash-attn-2.7.4.post1\n", "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager, possibly rendering your system unusable. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv. Use the --root-user-action option if you know what you are doing and want to suppress this warning.\u001b[0m\u001b[33m\n", "\u001b[0m" ] } ], "source": [ "!pip install flash-attn --no-build-isolation" ] }, { "cell_type": "code", "execution_count": null, "id": "24b9dd4e-789a-4162-b5b2-45c26f9b7504", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.10" } }, "nbformat": 4, "nbformat_minor": 5 }