{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0b1e5a2c",
   "metadata": {},
   "source": [
    "# Colour directions in the generator's W space (textiles)\n",
    "\n",
    "Fit linear models from the `W` latent vectors of generated textiles to the hue, saturation and value columns (`H1`, `S1`, `V1`) of the colour annotations. Then move a sample along the resulting directions to see how the generated image changes.\n",
    "\n",
    "- **Regression models:** one linear regression per HSV channel.\n",
    "- **Multiclass models:** hue binned into named colours, classified with a linear SVM and logistic regression."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3722712c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import pickle\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from PIL import Image, ImageColor\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Project star imports; presumably they provide `generate_original_image` and\n",
    "# `regenerate_images` (TODO: import those names explicitly). They stay above the\n",
    "# imports below so the explicitly imported names take precedence over anything\n",
    "# the star imports re-export.\n",
    "from backend.disentangle_concepts import *\n",
    "from backend.color_annotations import *\n",
    "\n",
    "import random\n",
    "\n",
    "import dnnlib\n",
    "import legacy\n",
    "from DisentanglementBase import DisentanglementBase\n",
    "from sklearn import svm\n",
    "from sklearn.linear_model import LinearRegression, LogisticRegression\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f3d9c10",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every source of randomness used below (sample picks, train/validation splits)\n",
    "# so that Restart & Run All reproduces the same results.\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03efb8c0",
   "metadata": {},
   "source": [
    "## Hue bins\n",
    "\n",
    "Standard six-colour hue classification (hue in degrees):\n",
    "\n",
    "| Hue range | Colour |\n",
    "|---|---|\n",
    "| 0-60 | Red |\n",
    "| 60-120 | Yellow |\n",
    "| 120-180 | Green |\n",
    "| 180-240 | Cyan |\n",
    "| 240-300 | Blue |\n",
    "| 300-360 | Magenta |\n",
    "\n",
    "The `bins` computed below also reserve the sliver [0, 1] for a seventh class, `Gray`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00a35126",
   "metadata": {},
   "outputs": [],
   "source": [
    "def hex2rgb(hex_value):\n",
    "    \"\"\"Convert a '#rrggbb' (or 'rrggbb') hex string to an (r, g, b) tuple of ints in 0-255.\"\"\"\n",
    "    h = hex_value.strip('#')\n",
    "    return tuple(int(h[i:i + 2], 16) for i in (0, 2, 4))\n",
    "\n",
    "\n",
    "def rgb2hsv(r, g, b):\n",
    "    \"\"\"Convert 0-255 RGB components to rounded (hue in degrees [0, 360], saturation 0-100, value 0-100).\"\"\"\n",
    "    r, g, b = r / 255.0, g / 255.0, b / 255.0\n",
    "    max_rgb = max(r, g, b)\n",
    "    min_rgb = min(r, g, b)\n",
    "    difference = max_rgb - min_rgb\n",
    "\n",
    "    # Hue is undefined for greys (max == min) and is set to 0; otherwise it\n",
    "    # depends on which channel is the largest.\n",
    "    if max_rgb == min_rgb:\n",
    "        h = 0\n",
    "    elif max_rgb == r:\n",
    "        h = (60 * ((g - b) / difference) + 360) % 360\n",
    "    elif max_rgb == g:\n",
    "        h = (60 * ((b - r) / difference) + 120) % 360\n",
    "    else:  # max_rgb == b\n",
    "        h = (60 * ((r - g) / difference) + 240) % 360\n",
    "\n",
    "    # Saturation is 0 for black (max == 0) to avoid dividing by zero.\n",
    "    s = 0 if max_rgb == 0 else (difference / max_rgb) * 100\n",
    "    v = max_rgb * 100\n",
    "    return tuple(map(round, (h, s, v)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5630402a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of colour classes: Gray plus six hue bins.\n",
    "num_colors = 7"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8428918",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bin edges in degrees: [0, 1] is reserved for Gray, followed by num_colors - 1\n",
    "# equal-width hue bins up to 360.\n",
    "bins = [0, 1] + [k * 360 / (num_colors - 1) for k in range(1, num_colors)]\n",
    "bins"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "00e57598",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Integer midpoint of each bin, used as the representative hue of each class.\n",
    "centers = [int((lo + hi) / 2) for lo, hi in zip(bins[:-1], bins[1:])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1550ecd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(bins)\n",
    "print(centers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab9be91e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_color_image(hue, saturation, value, size=(20, 10)):\n",
    "    \"\"\"Return a solid-colour RGB PIL image.\n",
    "\n",
    "    hue is in degrees; saturation and value are fractions in [0, 1].\n",
    "    \"\"\"\n",
    "    color_rgb = ImageColor.getrgb('hsv({}, {}%, {}%)'.format(hue, int(saturation * 100), int(value * 100)))\n",
    "    return Image.new('RGB', size, color_rgb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bf1c8ab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def display_image(image, title=''):\n",
    "    \"\"\"Show `image` in its own figure, without axes, with `title` as the figure title.\"\"\"\n",
    "    fig, ax = plt.subplots()\n",
    "    fig.suptitle(title)\n",
    "    ax.imshow(image)\n",
    "    ax.axis('off')\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f696758",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class names, in the same order as the hue bins above.\n",
    "names = ['Gray', 'Red', 'Yellow', 'Green', 'Cyan', 'Blue', 'Magenta']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50825823",
   "metadata": {},
   "outputs": [],
   "source": [
    "saturation = 1  # Saturation value (0 to 1)\n",
    "value = 1  # Value (brightness) value (0 to 1)\n",
    "for hue, name in zip(centers, names[:num_colors]):\n",
    "    # Draw the Gray class unsaturated; at full saturation its hue-0 swatch would look red.\n",
    "    image = create_color_image(hue, 0 if name == 'Gray' else saturation, value)\n",
    "    display_image(image, name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c4e81a57",
   "metadata": {},
   "source": [
    "## Load annotations and generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb",
   "metadata": {},
   "outputs": [],
   "source": [
    "annotations_file = './data/textile_annotated_files/seeds0000-100000.pkl'\n",
    "# NOTE: pickle.load can execute arbitrary code; only load trusted annotation files.\n",
    "with open(annotations_file, 'rb') as f:\n",
    "    annotations = pickle.load(f)\n",
    "\n",
    "# Missing entries are filled with black ('#000000').\n",
    "ann_df = pd.read_csv('./data/textile_annotated_files/top_three_colours.csv').fillna('#000000')\n",
    "\n",
    "with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:\n",
    "    model = legacy.load_network_pkl(f)['G_ema'].to('cpu')  # type: ignore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afb8a611",
   "metadata": {},
   "outputs": [],
   "source": [
    "variable = 'H1'\n",
    "disentanglement_exp = DisentanglementBase('.', model, annotations, ann_df, space='W', colors_list=names, compute_s=False, variable=variable)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7398217",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Annotation frame as prepared by DisentanglementBase (kept separate from the raw ann_df).\n",
    "exp_df = disentanglement_exp.df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd114cb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "exp_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "feb64168",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = np.array(annotations['w_vectors']).reshape((len(annotations['w_vectors']), 512))\n",
    "assert len(X) == len(exp_df), 'latent vectors and annotation rows must align'\n",
    "print(X.shape)\n",
    "y_h = np.array(exp_df['H1'].values)\n",
    "y_s = np.array(exp_df['S1'].values)\n",
    "# Value channel. Assumes the frame has a 'V1' column next to 'H1' / 'S1' (TODO confirm);\n",
    "# reading 'S1' here would make y_v a silent copy of y_s.\n",
    "y_v = np.array(exp_df['V1'].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ca08749",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the hue values (e.g. their range/scale) before binning them.\n",
    "np.unique(y_h)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8f33f14",
   "metadata": {},
   "source": [
    "## Regression models\n",
    "\n",
    "One linear regression per HSV channel. Each model's coefficient vector, normalised to unit length, is used as a direction in `W`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a6f0d93",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fit_direction(x_train, y_train, x_val, y_val, label):\n",
    "    \"\"\"Fit a LinearRegression, report its validation R^2 and return (model, unit-norm coef_).\"\"\"\n",
    "    reg = LinearRegression().fit(x_train, y_train)\n",
    "    # LinearRegression.score is the coefficient of determination (R^2), not an accuracy.\n",
    "    print(f'Val R^2 linear regression ({label})', np.round(reg.score(x_val, y_val), 2))\n",
    "    direction = reg.coef_ / np.linalg.norm(reg.coef_)\n",
    "    print(direction.shape)\n",
    "    return reg, direction\n",
    "\n",
    "\n",
    "def plot_traversal(images, lambdas, title=''):\n",
    "    \"\"\"Show traversal images in one row, each titled with its rounded step lambda.\"\"\"\n",
    "    # squeeze=False keeps axs 2-D, so a single image (count=1) also works.\n",
    "    fig, axs = plt.subplots(1, len(images), figsize=(4 * len(images), 4), squeeze=False)\n",
    "    fig.suptitle(title)\n",
    "    for ax, im, lam in zip(axs[0], images, lambdas):\n",
    "        ax.imshow(im)\n",
    "        ax.set_title(np.round(lam, 2))\n",
    "        ax.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "def unit_class_direction(clf, idx):\n",
    "    \"\"\"Return (row `idx` of clf.coef_ scaled to unit norm, the matching label clf.classes_[idx]).\"\"\"\n",
    "    coef = clf.coef_[idx, :]\n",
    "    return coef / np.linalg.norm(coef), clf.classes_[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8da0a43d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One shared split, so the three channel models see the same training/validation samples.\n",
    "(x_train, x_val,\n",
    " y_trainh, y_valh,\n",
    " y_trains, y_vals,\n",
    " y_trainv, y_valv) = train_test_split(X, y_h, y_s, y_v, test_size=0.2, random_state=SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8eddba20",
   "metadata": {},
   "outputs": [],
   "source": [
    "regh, separation_vectorh = fit_direction(x_train, y_trainh, x_val, y_valh, 'hue')\n",
    "regs, separation_vectors = fit_direction(x_train, y_trains, x_val, y_vals, 'saturation')\n",
    "regv, separation_vectorv = fit_direction(x_train, y_trainv, x_val, y_valv, 'value')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6a63345",
   "metadata": {},
   "outputs": [],
   "source": [
    "# randrange keeps the index in range. randint's upper bound is inclusive, so a\n",
    "# hard-coded randint(0, 100000) could overrun the list of stored vectors.\n",
    "seed = random.randrange(len(annotations['w_vectors']))\n",
    "original_image_vec = annotations['w_vectors'][seed]\n",
    "img = generate_original_image(original_image_vec, model, latent_space='W')\n",
    "img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09f13e6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Move the sample along the saturation direction.\n",
    "images, lambdas = regenerate_images(model, original_image_vec, separation_vectors, min_epsilon=-5, max_epsilon=5, count=7, latent_space='W')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c66bcdde",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_traversal(images, lambdas, 'Saturation regression direction')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4c44f0dd",
   "metadata": {},
   "source": [
    "**TODO / ideas:**\n",
    "\n",
    "- Fourier analysis to measure pattern regularity\n",
    "- Linear correlation with colour\n",
    "- Colour distributions: original vs. non-original\n",
    "- A neural network, to see how well it can classify the colours"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c2790c25",
   "metadata": {},
   "source": [
    "## Multiclass models\n",
    "\n",
    "Hue is binned into the colour classes in `names`. Some saturation/value extremes are then relabelled `Gray` (see the cell below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afa0c100",
   "metadata": {},
   "outputs": [],
   "source": [
    "colors_list = names"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39a5668a",
   "metadata": {},
   "source": [
    "**TODO:** double-check the colour labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f2b48c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_h_cat = pd.cut(y_h, bins=bins, labels=colors_list, include_lowest=True)\n",
    "print(y_h_cat.value_counts(dropna=False))\n",
    "\n",
    "# Relabel saturation/value extremes as Gray.\n",
    "# NOTE(review): if S/V are on a 0-100 scale, `== 100` marks fully saturated / fully\n",
    "# bright colours as Gray, which looks unintended -- confirm these rules.\n",
    "y_h_cat[y_s == 0] = 'Gray'\n",
    "y_h_cat[y_s == 100] = 'Gray'\n",
    "y_h_cat[y_v == 0] = 'Gray'\n",
    "y_h_cat[y_v == 100] = 'Gray'\n",
    "\n",
    "print(y_h_cat.value_counts(dropna=False))\n",
    "\n",
    "x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2, random_state=SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67651454",
   "metadata": {},
   "source": [
    "### Linear SVM and logistic regression"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7804f593",
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = svm.LinearSVC().fit(x_trainhc, y_trainhc)\n",
    "print('Val accuracy linear SVM', np.round(clf.score(x_valhc, y_valhc), 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e6e31b75",
   "metadata": {},
   "outputs": [],
   "source": [
    "clf_log = LogisticRegression(multi_class='ovr').fit(x_trainhc, y_trainhc)\n",
    "print('Val accuracy logistic regression', np.round(clf_log.score(x_valhc, y_valhc), 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82e30f0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pick a new random sample for the classifier directions (randrange stays in bounds).\n",
    "seed = random.randrange(len(annotations['w_vectors']))\n",
    "original_image_vec = annotations['w_vectors'][seed]\n",
    "img = generate_original_image(original_image_vec, model, latent_space='W')\n",
    "img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8ce6086",
   "metadata": {},
   "outputs": [],
   "source": [
    "y_predhc = clf.predict(x_valhc)\n",
    "print('Val accuracy linear SVM', np.round(accuracy_score(y_valhc, y_predhc), 2))\n",
    "\n",
    "# Row-normalise the confusion matrix so each diagonal entry is that class's accuracy (recall).\n",
    "cm = confusion_matrix(y_valhc, y_predhc, labels=clf.classes_)\n",
    "cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
    "pd.Series(cm.diagonal(), index=clf.classes_, name='per-class accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "112f4b87",
   "metadata": {},
   "outputs": [],
   "source": [
    "clf.classes_, clf.coef_.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6241bce1",
   "metadata": {},
   "outputs": [],
   "source": [
    "direction, class_name = unit_class_direction(clf, -3)\n",
    "\n",
    "images, lambdas = regenerate_images(model, original_image_vec, direction, min_epsilon=-5, max_epsilon=5, count=7, latent_space='W')\n",
    "plot_traversal(images, lambdas, f'Linear SVM direction: {class_name}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fefcf0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "direction, class_name = unit_class_direction(clf, -4)\n",
    "\n",
    "images, lambdas = regenerate_images(model, original_image_vec, direction, min_epsilon=-50, max_epsilon=50, count=2, latent_space='W')\n",
    "plot_traversal(images, lambdas, f'Linear SVM direction: {class_name}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0f31e0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative 6-colour binning on a 0-256 hue scale.\n",
    "# NOTE(review): the 7-class bins above assume hue in degrees (0-360). With these bins any\n",
    "# hue above 256 falls outside every bin and is then filled as 'Red' -- confirm the scale of H1.\n",
    "bins6 = [x * 256 / 6 if x < 6 else 256 for x in range(7)]\n",
    "y_h_cat6 = pd.cut(y_h, bins=bins6, labels=['Red', 'Yellow', 'Green', 'Blue',\n",
    "                                           'Purple', 'Pink']).fillna('Red')\n",
    "print(y_h_cat6.value_counts(dropna=False))\n",
    "\n",
    "# Separate names so the 7-class split above is not overwritten.\n",
    "x_train6, x_val6, y_train6, y_val6 = train_test_split(X, y_h_cat6, test_size=0.2, random_state=SEED)\n",
    "\n",
    "clf6 = svm.LinearSVC().fit(x_train6, y_train6)\n",
    "print('Val accuracy linear SVM (6 colours)', np.round(clf6.score(x_val6, y_val6), 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f5f28b41",
   "metadata": {},
   "outputs": [],
   "source": [
    "direction, class_name = unit_class_direction(clf6, 1)\n",
    "\n",
    "images, lambdas = regenerate_images(model, original_image_vec, direction, min_epsilon=-10, max_epsilon=10, count=7, latent_space='W')\n",
    "plot_traversal(images, lambdas, f'Linear SVM (6 colours) direction: {class_name}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4e0c7808",
   "metadata": {},
   "source": [
    "## Dimensionality reduction: where do the different colours end up? (TODO)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "833ed31f",
   "metadata": {},
   "source": [
    "## Clustering: what are the centroids of this space, and are there regions determined by colour? (TODO)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}