{ "cells": [
 { "cell_type": "markdown", "id": "fe83bcc2", "metadata": {}, "source": [
  "![image](./data/stylegan3.webp)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "3722712c", "metadata": {}, "outputs": [], "source": [
  "%matplotlib inline\n",
  "%load_ext autoreload\n",
  "%autoreload 2\n",
  "\n",
  "import pandas as pd\n",
  "import pickle\n",
  "import random\n",
  "\n",
  "import PIL\n",
  "from PIL import Image, ImageColor\n",
  "import matplotlib.pyplot as plt\n",
  "\n",
  "import numpy as np\n",
  "import torch\n",
  "\n",
  "from backend.disentangle_concepts import *\n",
  "from backend.color_annotations import *\n",
  "from backend.networks_stylegan3 import *\n",
  "import dnnlib\n",
  "import legacy\n",
  "\n",
  "from sklearn.linear_model import LinearRegression, LogisticRegression\n",
  "# Used by later cells; imported explicitly so they do not depend on what the\n",
  "# star imports above happen to re-export.\n",
  "from sklearn.model_selection import train_test_split\n",
  "from tqdm.auto import tqdm"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb", "metadata": {}, "outputs": [], "source": [
  "# NOTE: pickle.load can execute arbitrary code; only open annotation files you trust.\n",
  "annotations_file = './data/textile_annotated_files/seeds0000-100000.pkl'\n",
  "with open(annotations_file, 'rb') as f:\n",
  "    annotations = pickle.load(f)\n",
  "\n",
  "# Missing entries in the colour table are filled with black ('#000000').\n",
  "ann_df = pd.read_csv('./data/textile_annotated_files/top_three_colours.csv').fillna('#000000')\n",
  "\n",
  "with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:\n",
  "    model = legacy.load_network_pkl(f)['G_ema'].to('cpu')  # type: ignore\n"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "0e4b656e", "metadata": {}, "outputs": [], "source": [
  "# Broadcast the first annotated w vector to the 16 per-layer inputs and render it.\n",
  "z = torch.from_numpy(annotations['w_vectors'][0].copy()).to('cpu')\n",
  "W = z.expand((16, -1)).unsqueeze(0)\n",
  "img = model.synthesis(W, noise_mode='const')\n",
  "img.shape"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "1259f950", "metadata": {}, "outputs": [], "source": [
  "# Run the input layer and the first synthesis layer by hand.\n",
  "in_ = model.synthesis.input(W[0, 0].unsqueeze(0))\n",
  "l1 = model.synthesis.L0_36_512(in_, W[0, 1].unsqueeze(0))\n",
  "l1.shape"
 ] },
 { "cell_type": "code",
"execution_count": null, "id": "918feb0e", "metadata": {}, "outputs": [], "source": [
  "a = 'L0_36_512'\n",
  "getattr(model.synthesis, a)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "1bf7bfa4", "metadata": {}, "outputs": [], "source": [
  "def rest_from_style(x, styles, layer):\n",
  "    \"\"\"Run one synthesis layer of `model` on `x` using an explicit style vector.\n",
  "\n",
  "    This covers the part of the layer that comes after the affine mapping, so the\n",
  "    caller can edit `styles` before it is applied.\n",
  "\n",
  "    Args:\n",
  "        x: input feature tensor for the layer.\n",
  "        styles: style tensor for the layer, e.g. the output of the layer's `affine`.\n",
  "        layer: attribute name of the layer on `model.synthesis`, e.g. 'L0_36_512'.\n",
  "\n",
  "    Returns:\n",
  "        The output tensor of the layer.\n",
  "    \"\"\"\n",
  "    lyr = getattr(model.synthesis, layer)\n",
  "    # fp16 is only used when the data is on a CUDA device. The device is read from `x`\n",
  "    # so the function does not depend on a notebook-level `device` variable.\n",
  "    dtype = torch.float16 if (lyr.use_fp16 and x.device.type == 'cuda') else torch.float32\n",
  "    if lyr.is_torgb:\n",
  "        print(layer, lyr.is_torgb)\n",
  "        weight_gain = 1 / np.sqrt(lyr.in_channels * (lyr.conv_kernel ** 2))\n",
  "        styles = styles * weight_gain\n",
  "    input_gain = lyr.magnitude_ema.rsqrt().to(dtype)\n",
  "    # Execute modulated conv2d.\n",
  "    x = modulated_conv2d(x=x.to(dtype), w=lyr.weight.to(dtype), s=styles.to(dtype),\n",
  "                         padding=lyr.conv_kernel - 1, demodulate=(not lyr.is_torgb),\n",
  "                         input_gain=input_gain.to(dtype))\n",
  "    # Execute bias, filtered leaky ReLU, and clamping.\n",
  "    gain = 1 if lyr.is_torgb else np.sqrt(2)\n",
  "    slope = 1 if lyr.is_torgb else 0.2\n",
  "    x = filtered_lrelu.filtered_lrelu(x=x, fu=lyr.up_filter, fd=lyr.down_filter,\n",
  "                                      b=lyr.bias.to(x.dtype),\n",
  "                                      up=lyr.up_factor, down=lyr.down_factor,\n",
  "                                      padding=lyr.padding,\n",
  "                                      gain=gain, slope=slope, clamp=lyr.conv_clamp)\n",
  "    return x"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "c674780d", "metadata": {}, "outputs": [], "source": [
  "x1 = rest_from_style(in_, model.synthesis.L0_36_512.affine(W[0, 1].unsqueeze(0)), 'L0_36_512')\n",
  "x1.shape"
 ] },
 { "cell_type": "code", "execution_count": null,
"id": "0305ce16", "metadata": {}, "outputs": [], "source": [
  "def getS(w):\n",
  "    \"\"\"Map one w latent to the list of per-layer style vectors.\n",
  "\n",
  "    Args:\n",
  "        w: numpy array holding a single w latent (presumably shape (1, 512) -- confirm).\n",
  "\n",
  "    Returns:\n",
  "        List of 16 numpy arrays: the `affine` output of the input layer followed by\n",
  "        the `affine` output of each synthesis layer, in network order.\n",
  "    \"\"\"\n",
  "    affine_layers = ('input', 'L0_36_512', 'L1_36_512', 'L2_36_512', 'L3_52_512', 'L4_52_512',\n",
  "                     'L5_84_512', 'L6_84_512', 'L7_148_512', 'L8_148_512', 'L9_148_362',\n",
  "                     'L10_276_256', 'L11_276_181', 'L12_276_128', 'L13_256_128', 'L14_256_3')\n",
  "    w_torch = torch.from_numpy(w).to('cpu')\n",
  "    # One copy of w per layer; layer k reads W[0, k].\n",
  "    W = w_torch.expand((len(affine_layers), -1)).unsqueeze(0)\n",
  "    return [getattr(model.synthesis, name).affine(W[0, idx].unsqueeze(0)).numpy()\n",
  "            for idx, name in enumerate(affine_layers)]"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "981f5215", "metadata": {}, "outputs": [], "source": [
  "s = getS(annotations['w_vectors'][0])"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "389ad35a", "metadata": {}, "outputs": [], "source": [
  "# Width and name of every block of the concatenated [w | styles] vector.\n",
  "shapes = [512] + [style.shape[1] for style in s]\n",
  "layers = ['w', 'input', 'L0_36_512', 'L1_36_512', 'L2_36_512', 'L3_52_512', 'L4_52_512', 'L5_84_512', 'L6_84_512',\n",
  "          'L7_148_512', 'L8_148_512', 'L9_148_362', 'L10_276_256', 'L11_276_181', 'L12_276_128', 'L13_256_128',\n",
  "          'L14_256_3']\n",
  "sum(shapes), shapes"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "3c143e86", "metadata": {}, "outputs": [], "source": [
  "def generate_flexible_images(w, change_vectors, lambdas=1, device='cpu'):\n",
  "    \"\"\"Render `w` after shifting each layer's style vector by `lambdas * change`.\n",
  "\n",
  "    Args:\n",
  "        w: numpy array holding a single w latent.\n",
  "        change_vectors: one numpy vector per entry of `layers`; entry i is added to\n",
  "            the style of `layers[i]`. The first two entries ('w' and 'input') are\n",
  "            not applied.\n",
  "        lambdas: scalar strength of the edit.\n",
  "        device: torch device for the latent and the change vectors.\n",
  "\n",
  "    Returns:\n",
  "        The edited image as a PIL RGB image.\n",
  "    \"\"\"\n",
  "    w_torch = torch.from_numpy(w).to(device)\n",
  "    W = w_torch.expand((16, -1)).unsqueeze(0)\n",
  "\n",
  "    x = model.synthesis.input(W[0, 0].unsqueeze(0))\n",
  "    # layers[0] ('w') and layers[1] ('input') do not go through rest_from_style, so the\n",
  "    # style edits start at layers[2], whose latent is W[0, 1].\n",
  "    for i, layer in enumerate(layers[2:], start=2):\n",
  "        style = getattr(model.synthesis, layer).affine(W[0, i - 1].unsqueeze(0))\n",
  "        change = torch.from_numpy(change_vectors[i].copy()).unsqueeze(0).to(device)\n",
  "        style = torch.add(style, change, alpha=lambdas)\n",
  "        x = rest_from_style(x, style, layer)\n",
  "\n",
  "    if model.synthesis.output_scale != 1:\n",
  "        x = x * model.synthesis.output_scale\n",
  "\n",
  "    img = (x.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)\n",
  "    # `Image` is the name the notebook imports from PIL; the bare `PIL` module is not imported.\n",
  "    return Image.fromarray(img[0].cpu().numpy(), 'RGB')"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "f03915ff", "metadata": {}, "outputs": [], "source": [
  "def get_original_pos(top_positions):\n",
  "    \"\"\"Split flat channel indices into one normalised indicator vector per block.\n",
  "\n",
  "    `top_positions` index into the concatenation of the blocks whose widths are\n",
  "    listed in `shapes`. For each block, a vector of that width is returned with\n",
  "    ones at the selected channels, divided by its norm plus a small epsilon.\n",
  "    Blocks without a selected channel stay all-zero.\n",
  "\n",
  "    Args:\n",
  "        top_positions: iterable of integer indices into the concatenated vector.\n",
  "\n",
  "    Returns:\n",
  "        List with one numpy vector per entry of `shapes`.\n",
  "    \"\"\"\n",
  "    positions = np.asarray(top_positions, dtype=int)\n",
  "    vectors = []\n",
  "    block_start = 0\n",
  "    for width in shapes:\n",
  "        arr = np.zeros(width)\n",
  "        in_block = positions[(positions >= block_start) & (positions < block_start + width)]\n",
  "        arr[in_block - block_start] = 1\n",
  "        # The epsilon keeps all-zero blocks from dividing by zero.\n",
  "        vectors.append(arr / (np.linalg.norm(arr) + 0.000001))\n",
  "        block_start += width\n",
  "    return vectors\n"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "e76d836d", "metadata": {}, "outputs": [], "source": [
  "# Style vectors for every annotated w vector.\n",
  "ss = [getS(w) for w in tqdm(annotations['w_vectors'])]\n",
  "\n",
  "annotations['s_vectors'] = ss"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "6ea1ca59", "metadata": {}, "outputs": [],
"source": [
  "# Save the annotations together with the new 's_vectors' entry under a separate file name.\n",
  "annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'\n",
  "with open(annotations_file, 'wb') as f:\n",
  "    pickle.dump(annotations, f)\n"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "12f78bdb", "metadata": {}, "outputs": [], "source": [
  "len(ss)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "cd114cb1", "metadata": {}, "outputs": [], "source": [
  "ann_df = tohsv(ann_df)\n",
  "ann_df.head()"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "0d470f83", "metadata": {}, "outputs": [], "source": [
  "def getX(annotations, space='s'):\n",
  "    \"\"\"Build the feature matrix, one row per annotated sample.\n",
  "\n",
  "    Args:\n",
  "        annotations: dict with 'w_vectors' and, for space='s', 's_vectors'.\n",
  "        space: 'x' (or its alias 'w') for the plain w vectors reshaped to 512 columns;\n",
  "            's' for each w vector concatenated with all of its style vectors.\n",
  "\n",
  "    Returns:\n",
  "        2-D numpy array of features.\n",
  "\n",
  "    Raises:\n",
  "        ValueError: if `space` is not 'x', 'w' or 's'.\n",
  "    \"\"\"\n",
  "    w_vectors = annotations['w_vectors']\n",
  "    if space in ('x', 'w'):\n",
  "        return np.array(w_vectors).reshape((len(w_vectors), 512))\n",
  "    if space == 's':\n",
  "        concat_v = [np.concatenate([w] + list(styles), axis=1)\n",
  "                    for w, styles in zip(w_vectors, annotations['s_vectors'])]\n",
  "        # Each concatenated entry has a leading axis of length 1; drop it.\n",
  "        X = np.array(concat_v)[:, 0, :]\n",
  "        print(X.shape)\n",
  "        return X\n",
  "    raise ValueError(f\"unknown space {space!r}; expected 'x', 'w' or 's'\")\n"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "feb64168", "metadata": {}, "outputs": [], "source": [
  "X = getX(annotations)\n",
  "print(X.shape)\n",
  "y_h = np.array(ann_df['H1'].values)\n",
  "y_s = np.array(ann_df['S1'].values)\n",
  "# Value channel of the first colour (this previously re-read 'S1').\n",
  "y_v = np.array(ann_df['V1'].values)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "afa0c100", "metadata": {}, "outputs": [], "source": [
  "colors_list = ['Warm Pink Red', 'Red Orange', 'Orange Yellow', 'Gold Yellow', 'Chartreuse Green',\n",
  "               'Kelly Green', 'Green Blue Seafoam', 'Blue Green Cyan',\n",
  "               'Warm Blue', 'Indigo Blue Purple', 'Purple Magenta', 'Magenta Pink']"
 ] },
 { "cell_type": "markdown", "id": "39a5668a", "metadata": {}, "source": [
  "**TODO:** double-check the colours."
 ] },
 { "cell_type": "code", "execution_count": null, "id": "5f2b48c0", "metadata": {}, "outputs": [], "source": [
  "print([int(x*256/12) if x<12 else 255 for x in range(13)])\n",
  "# Twelve equal hue bins over [0, 256]. pd.cut leaves the left edge open, so a hue of\n",
  "# exactly 0 comes back as NaN; fillna assigns it to the first colour.\n",
  "hue_bin_edges = [x*256/12 if x<12 else 256 for x in range(13)]\n",
  "y_h_cat = pd.cut(y_h, bins=hue_bin_edges, labels=colors_list).fillna('Warm Pink Red')\n",
  "\n",
  "print(y_h_cat.value_counts(dropna=False))\n",
  "# Fixed random_state so the split, and everything derived from it, is reproducible.\n",
  "x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2, random_state=0)"
 ] },
 { "cell_type": "markdown", "id": "6c2a1765", "metadata": {}, "source": [
  "### Variance based"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "2be4202e", "metadata": {}, "outputs": [], "source": [
  "positives = x_trainhc[np.where(y_trainhc == 'Warm Blue')]\n",
  "print(positives.shape, x_trainhc.shape)\n",
  "variations = detect_attribute_specific_channels(positives, x_trainhc, sign=True)\n",
  "print(variations.shape, np.argmax(variations))"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "7d0c129d", "metadata": {}, "outputs": [], "source": [
  "# Indices and values of the five largest entries of `variations`, in ascending order.\n",
  "argsorted_vars = np.argsort(variations)[-5:]\n",
  "sorted_vars = np.sort(variations)[-5:]\n",
  "argsorted_vars, sorted_vars"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "e2c2ed49", "metadata": {}, "outputs": [], "source": [
  "original_pos = get_original_pos(argsorted_vars)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "82e30f0c", "metadata": {}, "outputs": [], "source": [
  "# Fixed index so the example image is the same on every run.\n",
  "seed = 52722\n",
  "original_image_vec = annotations['w_vectors'][seed]\n",
  "img = generate_original_image(original_image_vec, model, latent_space='W')\n",
  "img"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "cd71f2c8", "metadata": {}, "outputs": [], "source": [
  "device = 'cpu'\n",
  "img1 = generate_flexible_images(original_image_vec, original_pos, lambdas=-1)\n",
  "img1"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "abc5ac3f", "metadata": {}, "outputs": [], "source": [
  "img1 = generate_flexible_images(original_image_vec, original_pos, lambdas=-2)\n",
  "img1"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "0602dcab", "metadata": {}, "outputs": [], "source": [
  "len(original_pos)"
 ] },
 { "cell_type": "code", "execution_count": null, "id":
"d7eb412f", "metadata": {}, "outputs": [], "source": [
  "img1 = generate_flexible_images(original_image_vec, original_pos, lambdas=1)\n",
  "img1"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "03161270", "metadata": {}, "outputs": [], "source": [
  "seps, vals = all_variance_based_disentanglements(colors_list, x_trainhc, y_trainhc, k=10, sign=True, space='s')\n",
  "vals[2].shape"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "ae1016d6", "metadata": {}, "outputs": [], "source": [
  "# Index 0 of colors_list is 'Warm Pink Red'.\n",
  "warm_pink_val = get_verification_score(0, seps[0], model, annotations, samples=10, latent_space='W')\n",
  "warm_pink_val"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "b412cb25", "metadata": {}, "outputs": [], "source": [
  "# Index 8 of colors_list is 'Warm Blue'.\n",
  "warm_blue_val = get_verification_score(8, seps[8], model, annotations, samples=10, latent_space='W')\n",
  "warm_blue_val"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "6812cb6b", "metadata": {}, "outputs": [], "source": [
  "def show_lambda_sweep(images, lambdas, title=None, figsize=(50, 10)):\n",
  "    \"\"\"Show `images` in one row, each titled with its lambda rounded to 2 decimals.\n",
  "\n",
  "    Args:\n",
  "        images: sequence of images accepted by `Axes.imshow`.\n",
  "        lambdas: sequence of numbers, one per image.\n",
  "        title: optional figure title.\n",
  "        figsize: matplotlib figure size.\n",
  "    \"\"\"\n",
  "    fig, axs = plt.subplots(1, len(images), figsize=figsize)\n",
  "    if title is not None:\n",
  "        fig.suptitle(title, fontsize=20)\n",
  "    # plt.subplots returns a bare Axes (not an array) when there is a single image.\n",
  "    for ax, im, lam in zip(np.atleast_1d(axs), images, lambdas):\n",
  "        ax.imshow(im)\n",
  "        ax.set_title(np.round(lam, 2))\n",
  "    plt.show()\n",
  "\n",
  "\n",
  "seps, _ = all_variance_based_disentanglements(colors_list, x_trainhc, y_trainhc, k=10, sign=True)\n",
  "\n",
  "for sep, color in zip(seps, colors_list):\n",
  "    images, lambdas = regenerate_images(model, original_image_vec, sep, min_epsilon=-4, max_epsilon=4, count=5, latent_space='W')\n",
  "    show_lambda_sweep(images, lambdas, title=color)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "24b8f275", "metadata": {}, "outputs": [], "source": [
  "# The call returns a pair, as unpacked in the cells above. Without unpacking, the loop\n",
  "# below would iterate over the pair itself instead of the separation vectors.\n",
  "seps, _ = all_variance_based_disentanglements(colors_list, x_trainhc, y_trainhc, k=10, sign=True)\n",
  "\n",
  "for sep, color in zip(seps, colors_list):\n",
  "    images, lambdas = regenerate_images(model, original_image_vec, sep, min_epsilon=-4, max_epsilon=4, count=5, latent_space='W')\n",
  "    show_lambda_sweep(images, lambdas, title=color)"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "b6c61fbb", "metadata": {}, "outputs": [], "source": [
  "# `argsorted_vars` indexes the concatenated [w | styles] features. Only the first 512\n",
  "# entries are w channels, so larger indices cannot be placed in a 512-long W direction.\n",
  "w_channels = argsorted_vars[argsorted_vars < 512]\n",
  "separation_vector_onehot = np.zeros(512)\n",
  "separation_vector_onehot[w_channels] = 1\n",
  "\n",
  "images, lambdas = regenerate_images(model, original_image_vec, separation_vector_onehot, min_epsilon=-10, max_epsilon=10, count=7, latent_space='W')\n",
  "show_lambda_sweep(images, lambdas, figsize=(30, 200))"
 ] },
 { "cell_type": "code", "execution_count": null, "id": "7c19e820", "metadata": {}, "outputs": [], "source": [] }
],
"metadata": {
 "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" },
 "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.16" }
},
"nbformat": 4, "nbformat_minor": 5 }