aswin-raghavan committed on
Commit
731e661
·
1 Parent(s): d378ca4

init working on depth demo

Browse files
multimodal_domain_adaptation_using_HD.ipynb ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "Requirement already satisfied: matplotlib in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (3.5.1)\n",
13
+ "Requirement already satisfied: seaborn in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (0.11.2)\n",
14
+ "Requirement already satisfied: scikit-learn in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (1.0.2)\n",
15
+ "Requirement already satisfied: numpy in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (1.21.6)\n",
16
+ "Requirement already satisfied: pandas in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (1.3.5)\n",
17
+ "Requirement already satisfied: pillow in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (9.0.0)\n",
18
+ "Requirement already satisfied: transformers in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (4.24.0)\n",
19
+ "Requirement already satisfied: python-dateutil>=2.7 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from matplotlib) (2.8.2)\n",
20
+ "Requirement already satisfied: packaging>=20.0 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from matplotlib) (21.3)\n",
21
+ "Requirement already satisfied: cycler>=0.10 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from matplotlib) (0.11.0)\n",
22
+ "Requirement already satisfied: pyparsing>=2.2.1 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from matplotlib) (2.4.7)\n",
23
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from matplotlib) (1.3.2)\n",
24
+ "Requirement already satisfied: fonttools>=4.22.0 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from matplotlib) (4.28.5)\n",
25
+ "Requirement already satisfied: scipy>=1.0 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from seaborn) (1.7.1)\n",
26
+ "Requirement already satisfied: joblib>=0.11 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from scikit-learn) (1.1.0)\n",
27
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from scikit-learn) (3.1.0)\n",
28
+ "Requirement already satisfied: pytz>=2017.3 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from pandas) (2021.3)\n",
29
+ "Requirement already satisfied: pyyaml>=5.1 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from transformers) (6.0)\n",
30
+ "Requirement already satisfied: filelock in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from transformers) (3.8.0)\n",
31
+ "Requirement already satisfied: regex!=2019.12.17 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from transformers) (2022.10.31)\n",
32
+ "Requirement already satisfied: requests in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from transformers) (2.26.0)\n",
33
+ "Requirement already satisfied: tqdm>=4.27 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from transformers) (4.64.0)\n",
34
+ "Requirement already satisfied: importlib-metadata in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from transformers) (4.2.0)\n",
35
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.10.0 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from transformers) (0.16.4)\n",
36
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from transformers) (0.13.2)\n",
37
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from huggingface-hub<1.0,>=0.10.0->transformers) (4.7.1)\n",
38
+ "Requirement already satisfied: fsspec in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from huggingface-hub<1.0,>=0.10.0->transformers) (2022.1.0)\n",
39
+ "Requirement already satisfied: six>=1.5 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from python-dateutil>=2.7->matplotlib) (1.15.0)\n",
40
+ "Requirement already satisfied: zipp>=0.5 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from importlib-metadata->transformers) (3.6.0)\n",
41
+ "Requirement already satisfied: idna<4,>=2.5 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from requests->transformers) (3.2)\n",
42
+ "Requirement already satisfied: certifi>=2017.4.17 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from requests->transformers) (2021.5.30)\n",
43
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from requests->transformers) (1.26.7)\n",
44
+ "Requirement already satisfied: charset-normalizer~=2.0.0 in /Users/e29154/.pyenv/versions/3.7-dev/lib/python3.7/site-packages (from requests->transformers) (2.0.6)\n",
45
+ "\u001b[33mWARNING: You are using pip version 21.3.1; however, version 23.3.2 is available.\n",
46
+ "You should consider upgrading via the '/Users/e29154/.pyenv/versions/3.7-dev/bin/python3.7 -m pip install --upgrade pip' command.\u001b[0m\n"
47
+ ]
48
+ }
49
+ ],
50
+ "source": [
51
+ "!pip install matplotlib seaborn scikit-learn numpy pandas pillow transformers "
52
+ ]
53
+ },
54
+ {
55
+ "cell_type": "code",
56
+ "execution_count": 2,
57
+ "metadata": {},
58
+ "outputs": [
59
+ {
60
+ "name": "stderr",
61
+ "output_type": "stream",
62
+ "text": [
63
+ "2024-01-02 15:59:48.762816: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
64
+ "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
65
+ ]
66
+ }
67
+ ],
68
+ "source": [
69
+ "\n",
70
+ "import numpy as np\n",
71
+ "from numpy.random import MT19937\n",
72
+ "from numpy.random import RandomState, SeedSequence\n",
73
+ "import matplotlib.pyplot as plt\n",
74
+ "import seaborn as sns\n",
75
+ "sns.set_style('whitegrid')\n",
76
+ "rs = RandomState(MT19937(SeedSequence(123456789)))\n",
77
+ "import math\n",
78
+ "import pandas as pd\n",
79
+ "from turtle import title\n",
80
+ "import numpy as np\n",
81
+ "from PIL import Image\n",
82
+ "from transformers import CLIPProcessor, CLIPModel\n",
83
+ "import pandas as pd\n",
84
+ "from glob import glob\n",
85
+ "import random\n",
86
+ "from datetime import datetime\n",
87
+ "import numpy as np\n",
88
+ "from numpy.random import MT19937\n",
89
+ "from numpy.random import RandomState, SeedSequence"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 3,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "clip_model = CLIPModel.from_pretrained(\"openai/clip-vit-base-patch32\") \n",
99
+ "clip_processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-base-patch32\")"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "execution_count": null,
105
+ "metadata": {},
106
+ "outputs": [],
107
+ "source": []
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": 6,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "def make_LUT(nvalues, dims):\n",
116
+ " lut = np.zeros(shape=(nvalues, dims))\n",
117
+ " lut[0, :] = rs.binomial(n=1, p=0.5, size=(dims))\n",
118
+ " for row in range(1, nvalues):\n",
119
+ " lut[row, :] = lut[row-1, :]\n",
120
+ " # flip dims//nvalues randomly chosen bits relative to the previous row\n",
121
+ " rand_idx = rs.choice(dims, size=dims//nvalues, replace=False)\n",
122
+ " lut[row, rand_idx] = 1 - lut[row, rand_idx]\n",
123
+ " assert np.abs(lut[row, :] - lut[row-1, :]).sum() ==dims//nvalues \n",
124
+ " unique_rows = np.unique(lut, axis=0)\n",
125
+ " assert len(unique_rows) == len(lut)\n",
126
+ " return lut"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": 8,
132
+ "metadata": {},
133
+ "outputs": [
134
+ {
135
+ "name": "stdout",
136
+ "output_type": "stream",
137
+ "text": [
138
+ "(256,) -1.0 1.0 val bins\n"
139
+ ]
140
+ }
141
+ ],
142
+ "source": [
143
+ "HYPERDIMS = 1024\n",
144
+ "VALUE_BITS = 8\n",
145
+ "POS_BITS = 9 # CLIP features are 512 dims\n",
146
+ "val_bins = np.linspace(start=-1., stop=1., num=2**VALUE_BITS)\n",
147
+ "print(val_bins.shape, val_bins.min(), val_bins.max(), 'val bins')\n",
148
+ "val_lut = make_LUT(2**VALUE_BITS, HYPERDIMS)\n",
149
+ "assert val_lut.shape[0] == val_bins.shape[0]\n",
150
+ "pos_lut = rs.binomial(n=1, p=0.5, size=(2**POS_BITS, HYPERDIMS))"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "execution_count": 10,
156
+ "metadata": {},
157
+ "outputs": [],
158
+ "source": [
159
+ "def extract_features(image):\n",
160
+ " PIL_image = Image.fromarray(np.uint8(image)).convert('RGB')\n",
161
+ " inputs = clip_processor(text=[\"a photo of a cat\", \"a photo of a dog\"], images=PIL_image, return_tensors=\"pt\", padding=True)\n",
162
+ " outputs = clip_model(**inputs)\n",
163
+ " # print(outputs.image_embeds.shape)\n",
164
+ " return outputs.image_embeds"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": [
173
+ "\n",
174
+ " \n",
175
+ "def quantize_embeds(embeds):\n",
176
+ " assert np.all(embeds >= val_bins[0])\n",
177
+ " assert np.all(embeds <= val_bins[-1])\n",
178
+ " embeds_flat = embeds.flatten()\n",
179
+ "\n",
180
+ " all_pairs_dist = np.abs(embeds_flat[:, np.newaxis] - val_bins[np.newaxis, :])\n",
181
+ " closest_bin = np.argmin(all_pairs_dist, axis=-1)\n",
182
+ " quantized_embeds_flat = val_bins[closest_bin]\n",
183
+ " quantized_embeds = np.reshape(quantized_embeds_flat, embeds.shape)\n",
184
+ " closest_bin = np.reshape(closest_bin, embeds.shape)\n",
185
+ " print(closest_bin.shape, 'values are in bins', closest_bin.min(), 'to', closest_bin.max())\n",
186
+ " print('abs quant error avg', np.abs(embeds - quantized_embeds).mean())\n",
187
+ " return quantized_embeds, closest_bin\n",
188
+ "\n",
189
+ "def update_exemplars(df, rng, exemplars, lut):\n",
190
+ " embeds = np.array(df['image_embed'].values.tolist()) # df[['image_embed']].to_numpy()\n",
191
+ " labels = np.array(df['label'].values.tolist(), 'int')\n",
192
+ " # print(labels, labels.shape)\n",
193
+ " assert np.all(np.unique(labels) == [0, 1])\n",
194
+ " labels_zero_idx = (labels == 0).nonzero()[0]\n",
195
+ " labels_one_idx = (labels == 1).nonzero()[0]\n",
196
+ " print(labels_zero_idx.shape, \" zeros and \", labels_one_idx.shape, \" ones\")\n",
197
+ " # 70-30 split\n",
198
+ " labels_zero_train_idx = rng[0].choice(labels_zero_idx, size=int(.7 * len(labels_zero_idx)), replace=False)\n",
199
+ " labels_one_train_idx = rng[0].choice(labels_one_idx, size=int(.7 * len(labels_one_idx)), replace=False)\n",
200
+ " embeds_train = np.concatenate([embeds[labels_zero_train_idx], embeds[labels_one_train_idx]], axis=0)\n",
201
+ " labels_train = np.concatenate([labels[labels_zero_train_idx], labels[labels_one_train_idx]], axis=0)\n",
202
+ " print('Training set ', embeds_train.shape, labels_train.shape)\n",
203
+ " print(np.sum(labels_train == 0), \" zeros and \", np.sum(labels_train == 1).sum(), \" ones\")\n",
204
+ " labels_zero_test_idx = np.setdiff1d(labels_zero_idx, labels_zero_train_idx)\n",
205
+ " labels_one_test_idx = np.setdiff1d(labels_one_idx, labels_one_train_idx)\n",
206
+ " embeds_test = np.concatenate([embeds[labels_zero_test_idx], embeds[labels_one_test_idx]], axis=0)\n",
207
+ " labels_test = np.concatenate([labels[labels_zero_test_idx], labels[labels_one_test_idx]], axis=0)\n",
208
+ " print('Test set ', embeds_test.shape, labels_test.shape)\n",
209
+ "\n",
210
+ " quantized_embeds, closest_bin = quantize_embeds(embeds_train)\n",
211
+ " # closest_bin is nexample x 512 (one bin index per feature)\n",
212
+ " # lut[0] is nvals x dims\n",
213
+ " # hd_embeds_per_pos is nexample x 512 x dims\n",
214
+ " hd_embeds_per_pos = lut[0][closest_bin]\n",
215
+ " # bundle along pos dimension 512\n",
216
+ " # lut[1] is 512 x dims\n",
217
+ " xor = lambda a,b: a*(1.-b) + b*(1.-a)\n",
218
+ " hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos)\n",
219
+ " hd_embeds = np.sum(hd_embeds, axis=1) / embeds_train.shape[-1]\n",
220
+ " hd_embeds[hd_embeds >= 0.5] = 1.\n",
221
+ " hd_embeds[hd_embeds < 0.5] = 0.\n",
222
+ " # hd_embeds is now binary (thresholded at 0.5), nexample x dims\n",
223
+ " \n",
224
+ " exemplars_integer = [None, None]\n",
225
+ " exemplars_integer[0] = np.sum(hd_embeds[labels_train == 0], axis=0)\n",
226
+ " exemplars_integer[1] = np.sum(hd_embeds[labels_train == 1], axis=0)\n",
227
+ " exemplars[0] = exemplars_integer[0] / np.sum(labels_train == 0)\n",
228
+ " exemplars[1] = exemplars_integer[1] / np.sum(labels_train == 1)\n",
229
+ " exemplars[0][exemplars[0] >= 0.5] = 1.\n",
230
+ " exemplars[0][exemplars[0] < 0.5] = 0.\n",
231
+ " exemplars[1][exemplars[1] >= 0.5] = 1.\n",
232
+ " exemplars[1][exemplars[1] < 0.5] = 0.\n",
233
+ " print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum())\n",
234
+ " preds = np.zeros(hd_embeds.shape[0])\n",
235
+ " dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1)\n",
236
+ " dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1)\n",
237
+ " preds[dist_to_ex1 < dist_to_ex0] = 1\n",
238
+ " print(preds.shape, labels_train.shape, np.sum(preds == labels_train))\n",
239
+ " train_acc = np.sum(preds == labels_train) / len(labels_train)\n",
240
+ " rng, test_acc = score(embeds_test, labels_test, rng, exemplars, lut)\n",
241
+ " return rng, exemplars, train_acc, test_acc\n",
242
+ "\n",
243
+ "def score(embeds, labels, rng, exemplars, lut):\n",
244
+ " quantized_embeds, closest_bin = quantize_embeds(embeds)\n",
245
+ " # closest_bin is nexample x 512 (one bin index per feature)\n",
246
+ " # lut[0] is nvals x dims\n",
247
+ " # hd_embeds_per_pos is nexample x 512 x dims\n",
248
+ " hd_embeds_per_pos = lut[0][closest_bin]\n",
249
+ " # bundle along pos dimension 512\n",
250
+ " # lut[1] is 512 x dims\n",
251
+ " xor = lambda a,b: a*(1.-b) + b*(1.-a)\n",
252
+ " hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos)\n",
253
+ " hd_embeds = np.sum(hd_embeds, axis=1) / embeds.shape[-1]\n",
254
+ " hd_embeds[hd_embeds >= 0.5] = 1.\n",
255
+ " hd_embeds[hd_embeds < 0.5] = 0.\n",
256
+ " # hd_embeds is now binary (thresholded at 0.5), nexample x dims\n",
257
+ " print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum())\n",
258
+ " preds = np.zeros(hd_embeds.shape[0])\n",
259
+ " dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1)\n",
260
+ " dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1)\n",
261
+ " preds[dist_to_ex1 < dist_to_ex0] = 1\n",
262
+ " print(preds.shape, labels.shape, np.sum(preds == labels), len(labels))\n",
263
+ " acc = np.sum(preds == labels) / len(labels)\n",
264
+ " return rng, acc\n",
265
+ "\n",
266
+ "def predict(embeds, exemplars, lut):\n",
267
+ " quantized_embeds, closest_bin = quantize_embeds(embeds)\n",
268
+ " # closest bin is nexample X 512\n",
269
+ " # lut[0] is nvals X dims\n",
270
+ " # hd_embeds in nexample x 512 x dims\n",
271
+ " hd_embeds_per_pos = lut[0][closest_bin]\n",
272
+ " # bundle along pos dimension 512\n",
273
+ " # lut[1] is 512 x dims\n",
274
+ " xor = lambda a,b: a*(1.-b) + b*(1.-a)\n",
275
+ " hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos)\n",
276
+ " hd_embeds = np.sum(hd_embeds, axis=1) / embeds.shape[-1]\n",
277
+ " hd_embeds[hd_embeds >= 0.5] = 1.\n",
278
+ " hd_embeds[hd_embeds < 0.5] = 0.\n",
279
+ " # hd_embeds is now binary (thresholded at 0.5), nexample x dims\n",
280
+ " # print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum())\n",
281
+ " dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1)\n",
282
+ " dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1)\n",
283
+ " print('dists', dist_to_ex0, dist_to_ex1)\n",
284
+ " odds = np.abs(dist_to_ex0 - dist_to_ex1).item()\n",
285
+ " if dist_to_ex1 < dist_to_ex0:\n",
286
+ " preds = np.array([1., odds])\n",
287
+ " else:\n",
288
+ " preds = np.array([odds, 1.])\n",
289
+ " print(preds)\n",
290
+ " # preds = np.array([-1. * dist_to_ex0, -1. * dist_to_ex1])\n",
291
+ " preds = preds / preds.sum()\n",
292
+ " # print(preds.shape)\n",
293
+ " print(preds)\n",
294
+ " return {\"👍\": preds[1], \"👎\": preds[0]}"
295
+ ]
296
+ }
297
+ ],
298
+ "metadata": {
299
+ "kernelspec": {
300
+ "display_name": "midas-py310",
301
+ "language": "python",
302
+ "name": "python3"
303
+ },
304
+ "language_info": {
305
+ "codemirror_mode": {
306
+ "name": "ipython",
307
+ "version": 3
308
+ },
309
+ "file_extension": ".py",
310
+ "mimetype": "text/x-python",
311
+ "name": "python",
312
+ "nbconvert_exporter": "python",
313
+ "pygments_lexer": "ipython3",
314
+ "version": "3.7.12+"
315
+ }
316
+ },
317
+ "nbformat": 4,
318
+ "nbformat_minor": 2
319
+ }