{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "873b1354-b85f-4c5b-9163-95190f07b39a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import zipfile\n",
    "from PIL import Image\n",
    "from io import BytesIO\n",
    "import numpy as np\n",
    "from datasets import load_dataset\n",
    "import torch\n",
    "from diffusers import AutoencoderKL, UNet2DModel, UNet2DConditionModel\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "35949720-3e01-43b0-8487-a1b2131d5a9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess_image(image):\n",
    "    \"\"\"Turn a PIL image into a float32 NCHW array scaled to [-1, 1].\n",
    "\n",
    "    Args:\n",
    "        image: PIL image in any mode; it is converted to RGB first.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray of shape (1, 3, H, W), where H and W are the image\n",
    "        sides rounded down to integer multiples of 32.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if either side is shorter than 32 pixels.\n",
    "    \"\"\"\n",
    "    # Grayscale arrays are 2-D (the 4-axis transpose below would fail) and\n",
    "    # RGBA adds a 4th channel, so normalise the mode up front.\n",
    "    image = image.convert(\"RGB\")\n",
    "    w, h = (side - side % 32 for side in image.size)  # resize to integer multiple of 32\n",
    "    if w == 0 or h == 0:\n",
    "        raise ValueError(f\"image must be at least 32x32 pixels, got {image.size}\")\n",
    "    image = image.resize((w, h), resample=Image.Resampling.LANCZOS)\n",
    "    pixels = np.array(image).astype(np.float32) / 255.0\n",
    "    pixels = pixels[None].transpose(0, 3, 1, 2)  # HWC -> 1CHW\n",
    "    return 2.0 * pixels - 1.0\n",
    "\n",
    "def vae_embedding(preprocessed, num_samples=5, device=\"cuda\", vae_model=None, scaling_factor=0.18215):\n",
    "    \"\"\"Average several scaled latent samples drawn for one preprocessed image.\n",
    "\n",
    "    Args:\n",
    "        preprocessed: array or tensor fed to the encoder, e.g. the output of\n",
    "            preprocess_image.\n",
    "        num_samples: number of latent samples to draw and average (>= 1).\n",
    "        device: device the encoder input is moved to.\n",
    "        vae_model: object exposing encode(x).latent_dist.sample(); defaults to\n",
    "            the notebook-level `vae`.\n",
    "        scaling_factor: multiplier applied to every sample.\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray holding the mean of the scaled, squeezed samples.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if num_samples is smaller than 1.\n",
    "    \"\"\"\n",
    "    if num_samples < 1:\n",
    "        raise ValueError(\"num_samples must be at least 1\")\n",
    "    if vae_model is None:\n",
    "        vae_model = vae  # fall back to the VAE defined elsewhere in the notebook\n",
    "    with torch.no_grad():\n",
    "        # preprocess_image returns a NumPy array, which has no .to();\n",
    "        # as_tensor leaves an existing tensor untouched.\n",
    "        batch = torch.as_tensor(preprocessed).to(device=device)\n",
    "        latent_dist = vae_model.encode(batch).latent_dist\n",
    "        latents = [scaling_factor * latent_dist.sample().to(\"cpu\").squeeze() for _ in range(num_samples)]  # sample num_samples latent vecs\n",
    "        latents = torch.stack(latents)  # stack them\n",
    "        return torch.mean(latents, dim=0).numpy()  #average them. 
output shape: (4,64,64)" ] }, { "cell_type": "code", "execution_count": 3, "id": "6ebd9d84-98f7-4883-ac4b-0ec875b86911", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using custom data configuration SDbiaseval--dataset-cc8e38e46c1acd54\n", "Found cached dataset parquet (/mnt/1da05489-3812-4f15-a6e5-c8d3c57df39e/cache/huggingface/SDbiaseval___parquet/SDbiaseval--dataset-cc8e38e46c1acd54/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "f184861d2e2749c9b7c1c1ea3910be27", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/1 [00:00