import random
from datetime import datetime
from glob import glob

import gradio as gr
import numpy as np
import pandas as pd
from PIL import Image
from numpy.random import MT19937, RandomState, SeedSequence
from transformers import CLIPProcessor, CLIPModel

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

HYPERDIMS = 1024
VALUE_BITS = 8
POS_BITS = 9 # 2**9 = 512 position hypervectors, one per CLIP feature dimension
val_bins = np.linspace(start=-1., stop=1., num=2**VALUE_BITS)
print(val_bins.shape, val_bins.min(), val_bins.max(), 'val bins')
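
# Hyperdimensional-computing (HDC) scheme used below, for orientation:
#   1. quantize each component of the 512-dim CLIP embedding to one of
#      2**VALUE_BITS levels (val_bins),
#   2. map each level to a HYPERDIMS-bit hypervector via a lookup table
#      (lut_state[0]) and each of the 2**POS_BITS = 512 feature positions to a
#      random hypervector (lut_state[1]),
#   3. bind value and position vectors with XOR and bundle the 512 results by
#      majority vote into one HYPERDIMS-bit code per image,
#   4. bundle the codes of each class into an exemplar and classify new images
#      by Hamming distance to the two exemplars.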

def extract_features(image):
    PIL_image = Image.fromarray(np.uint8(image)).convert('RGB')
    inputs = clip_processor(text=["a photo of a cat", "a photo of a dog"], images=PIL_image, return_tensors="pt", padding=True)
    outputs = clip_model(**inputs)
    # print(outputs.image_embeds.shape)
    return outputs.image_embeds
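
# Note on extract_features: the two text prompts are only placeholders so the
# CLIPProcessor call succeeds; only the image output is used. The returned
# image_embeds are CLIP's L2-normalized image projections, which is what keeps
# every component inside the [-1, 1] range assumed by quantize_embeds() below.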


def update_table(img, img_name, df, state, label, exemplars_state, lut_state):
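    # Flow: embed the shown image, append (image_name, embedding, label) to the
    # annotation table, drop the shown image from the image queue, lightly
    # shuffle the tail of the queue, surface the next image, and refresh the
    # prediction widget using the current exemplars.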
    img_embeds = extract_features(img).detach().numpy().squeeze().tolist()
    print(img_name, img.shape, len(img_embeds), 'images left:', len(state))
    new_df = pd.DataFrame({'image_name': img_name, 'label': label, 'image_embed': None}, columns=['image_name', 'image_embed', 'label'], index=[0])
    # print(new_df)
    new_df.at[0, 'image_embed'] = img_embeds
    df = pd.concat([df, new_df])
    filt = df["image_name"] != ""
    df = df[filt]
    state.pop()
    t = state[-10:]
    random.shuffle(t)
    state = state[:-10] + t
    idx = -1 
    next_img = state[idx]
    preds = predict(extract_features(img).detach().numpy(), exemplars_state, lut_state)
    return next_img, next_img, df, state, preds

def update_table_up(img, img_name, df, state, exemplars_state, lut_state):
    return update_table(img, img_name, df, state, 1, exemplars_state, lut_state)

def update_table_down(img, img_name, df, state, exemplars_state, lut_state):
    return update_table(img, img_name, df, state, 0, exemplars_state, lut_state)

def make_LUT(nvalues, dims, rs):
    lut = np.zeros(shape=(nvalues, dims))
    lut[0, :] = rs.binomial(n=1, p=0.5, size=(dims))
    for row in range(1, nvalues):
        lut[row, :] = lut[row-1, :]
        # flip a few randomly chosen bits relative to the previous row
        rand_idx = rs.choice(dims, size=dims//nvalues, replace=False)
        lut[row, rand_idx] = 1 - lut[row, rand_idx]
        assert np.abs(lut[row, :] - lut[row-1, :]).sum() == dims // nvalues
    unique_rows = np.unique(lut, axis=0)
    assert len(unique_rows) == len(lut)
    return lut
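
# make_LUT builds a "level" code: row 0 is a random HYPERDIMS-bit vector and each
# subsequent row flips dims // nvalues randomly chosen bits of the previous row,
# so nearby quantization levels get nearby hypervectors (small Hamming distance)
# while distant levels drift apart. A minimal sketch of that property, with
# hypothetical values not used by the app (uncomment to try):
#   _rs = RandomState(MT19937(SeedSequence(0)))
#   _lut = make_LUT(nvalues=8, dims=256, rs=_rs)
#   print(np.abs(_lut[0] - _lut[1]).sum())   # exactly 256 // 8 = 32 bits differ
#   print(np.abs(_lut[0] - _lut[7]).sum())   # roughly 7 * 32, minus any re-flipped bits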

def load_fn(images, rng_state, exemplars_state, lut_state):
    rs = RandomState(MT19937(SeedSequence(123456789)))
    rng_state[0] = rs
    exemplars_state[0] = rs.binomial(n=1, p=0.5, size=HYPERDIMS)
    exemplars_state[1] = rs.binomial(n=1, p=0.5, size=HYPERDIMS)
    lut_state[0] = make_LUT(2**VALUE_BITS, HYPERDIMS, rs)
    assert lut_state[0].shape[0] == val_bins.shape[0]
    lut_state[1] = rs.binomial(n=1, p=0.5, size=(2**POS_BITS, HYPERDIMS))
    print(exemplars_state)
    print(lut_state[0].shape, lut_state[1].shape)
    return images[-1], images[-1], rng_state, exemplars_state, lut_state
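
# State layout after load_fn runs:
#   rng_state[0]         - seeded RandomState driving all HD randomness
#   exemplars_state[0/1] - class hypervectors for labels 0 (down) and 1 (up)
#   lut_state[0]         - value LUT, shape (2**VALUE_BITS, HYPERDIMS)
#   lut_state[1]         - position hypervectors, shape (2**POS_BITS, HYPERDIMS)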

def quantize_embeds(embeds):
    assert np.all(embeds >= val_bins[0])
    assert np.all(embeds <= val_bins[-1])
    embeds_flat = embeds.flatten()

    all_pairs_dist = np.abs(embeds_flat[:, np.newaxis] - val_bins[np.newaxis, :])
    closest_bin = np.argmin(all_pairs_dist, axis=-1)
    quantized_embeds_flat = val_bins[closest_bin]
    quantized_embeds = np.reshape(quantized_embeds_flat, embeds.shape)
    closest_bin = np.reshape(closest_bin, embeds.shape)
    print(closest_bin.shape, 'values are in bins', closest_bin.min(), 'to', closest_bin.max())
    print('abs quant error avg', np.abs(embeds - quantized_embeds).mean())
    return quantized_embeds, closest_bin
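
# With VALUE_BITS = 8 the 256 bins span [-1, 1] in steps of 2 / 255 (about 0.0078),
# so the worst-case per-component quantization error is about 0.004; the printed
# "abs quant error avg" above should stay well below that bound.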

def update_exemplars(df, rng, exemplars, lut):
    embeds = np.array(df['image_embed'].values.tolist()) # df[['image_embed']].to_numpy()
    labels = np.array(df['label'].values.tolist(), 'int')
    # print(labels, labels.shape)
    assert np.all(np.unique(labels) == [0, 1])
    labels_zero_idx = (labels == 0).nonzero()[0]
    labels_one_idx = (labels == 1).nonzero()[0]
    print(labels_zero_idx.shape, " zeros and ", labels_one_idx.shape, " ones")
    # 70-30 split
    labels_zero_train_idx = rng[0].choice(labels_zero_idx, size=int(.7 * len(labels_zero_idx)), replace=False)
    labels_one_train_idx = rng[0].choice(labels_one_idx, size=int(.7 * len(labels_one_idx)), replace=False)
    embeds_train = np.concatenate([embeds[labels_zero_train_idx], embeds[labels_one_train_idx]], axis=0)
    labels_train = np.concatenate([labels[labels_zero_train_idx], labels[labels_one_train_idx]], axis=0)
    print('Training set ', embeds_train.shape, labels_train.shape)
    print(np.sum(labels_train == 0), " zeros and ", np.sum(labels_train == 1), " ones")
    labels_zero_test_idx = np.setdiff1d(labels_zero_idx, labels_zero_train_idx)
    labels_one_test_idx = np.setdiff1d(labels_one_idx, labels_one_train_idx)
    embeds_test = np.concatenate([embeds[labels_zero_test_idx], embeds[labels_one_test_idx]], axis=0)
    labels_test = np.concatenate([labels[labels_zero_test_idx], labels[labels_one_test_idx]], axis=0)
    print('Test set ', embeds_test.shape, labels_test.shape)

    quantized_embeds, closest_bin = quantize_embeds(embeds_train)
    # closest bin is nexample X 512
    # lut[0] is nvals X dims
    # hd_embeds in nexample x 512 x dims
    hd_embeds_per_pos = lut[0][closest_bin]
    # bundle along pos dimension 512
    # lut[1] is 512 x dims
    xor = lambda a,b: a*(1.-b) + b*(1.-a)
    hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos)
    hd_embeds = np.sum(hd_embeds, axis=1) / embeds_train.shape[-1]
    hd_embeds[hd_embeds >= 0.5] = 1.
    hd_embeds[hd_embeds < 0.5] = 0.
    # hd_embeds_integer is nexample x dims
    
    exemplars_integer = [None, None]
    exemplars_integer[0] = np.sum(hd_embeds[labels_train == 0], axis=0)
    exemplars_integer[1] = np.sum(hd_embeds[labels_train == 1], axis=0)
    exemplars[0] = exemplars_integer[0] / np.sum(labels_train == 0)
    exemplars[1] = exemplars_integer[1] / np.sum(labels_train == 1)
    exemplars[0][exemplars[0] >= 0.5] = 1.
    exemplars[0][exemplars[0] < 0.5] = 0.
    exemplars[1][exemplars[1] >= 0.5] = 1.
    exemplars[1][exemplars[1] < 0.5] = 0.
    print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum())
    preds = np.zeros(hd_embeds.shape[0])
    dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1)
    dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1)
    preds[dist_to_ex1 < dist_to_ex0] = 1
    print(preds.shape, labels_train.shape, np.sum(preds == labels_train))
    train_acc = np.sum(preds == labels_train) / len(labels_train)
    rng, test_acc = score(embeds_test, labels_test, rng, exemplars, lut)
    return rng, exemplars, train_acc, test_acc

def score(embeds, labels, rng, exemplars, lut):
    quantized_embeds, closest_bin = quantize_embeds(embeds)
    # closest bin is nexample X 512
    # lut[0] is nvals X dims
    # hd_embeds in nexample x 512 x dims
    hd_embeds_per_pos = lut[0][closest_bin]
    # bundle along pos dimension 512
    # lut[1] is 512 x dims
    xor = lambda a,b: a*(1.-b) + b*(1.-a)
    hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos)
    hd_embeds = np.sum(hd_embeds, axis=1) / embeds.shape[-1]
    hd_embeds[hd_embeds >= 0.5] = 1.
    hd_embeds[hd_embeds < 0.5] = 0.
    # hd_embeds_integer is nexample x dims
    print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum())
    preds = np.zeros(hd_embeds.shape[0])
    dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1)
    dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1)
    preds[dist_to_ex1 < dist_to_ex0] = 1
    print(preds.shape, labels.shape, np.sum(preds == labels), len(labels))
    acc = np.sum(preds == labels) / len(labels)
    return rng, acc

def predict(embeds, exemplars, lut):
    quantized_embeds, closest_bin = quantize_embeds(embeds)
    # closest bin is nexample X 512
    # lut[0] is nvals X dims
    # hd_embeds in nexample x 512 x dims
    hd_embeds_per_pos = lut[0][closest_bin]
    # bundle along pos dimension 512
    # lut[1] is 512 x dims
    xor = lambda a,b: a*(1.-b) + b*(1.-a)
    hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos)
    hd_embeds = np.sum(hd_embeds, axis=1) / embeds.shape[-1]
    hd_embeds[hd_embeds >= 0.5] = 1.
    hd_embeds[hd_embeds < 0.5] = 0.
    # hd_embeds_integer is nexample x dims
    # print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum())
    dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1)
    dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1)
    print('dists', dist_to_ex0, dist_to_ex1)
    odds = np.abs(dist_to_ex0 - dist_to_ex1).item()
    if dist_to_ex1 < dist_to_ex0:
        preds = np.array([1., odds])
    else:
        preds = np.array([odds, 1.])
    print(preds)
    # preds = np.array([-1. * dist_to_ex0, -1. * dist_to_ex1])
    preds = preds / preds.sum()
    # print(preds.shape)
    print(preds)
    return {"๐Ÿ‘": preds[1], "๐Ÿ‘Ž": preds[0]}
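
# The same encode step (value-LUT lookup, XOR-bind with position codes, majority
# bundle) is repeated verbatim in update_exemplars, score and predict. The helper
# below mirrors that block as a refactoring sketch only; it is not wired into the
# handlers.
def encode_hd(embeds, lut):
    _, closest_bin = quantize_embeds(embeds)             # (n, 512) bin indices
    hd_per_pos = lut[0][closest_bin]                      # (n, 512, HYPERDIMS)
    xor = lambda a, b: a * (1. - b) + b * (1. - a)        # XOR on {0., 1.} arrays
    bound = xor(lut[1][np.newaxis, ...], hd_per_pos)      # bind value with position
    bundled = np.sum(bound, axis=1) / embeds.shape[-1]    # majority vote over positions
    return (bundled >= 0.5).astype(float)                 # (n, HYPERDIMS) binary codes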
 
with gr.Blocks(title="End-User Personalization") as demo:
    img_list = glob('images/**/*.jpg')
    random.seed(datetime.now().timestamp())
    random.shuffle(img_list)
    images = gr.State(img_list)
    # start_button = gr.Button(label="Start")
    with gr.Row():
        image_display = gr.Image()
        with gr.Column():
            image_fname = gr.Textbox()
            preds = gr.Label(label="Prediction")
    # text_display = gr.Text()
    with gr.Row():
        upvote = gr.Button("๐Ÿ‘")
        downvote = gr.Button("๐Ÿ‘Ž")
    personalize = gr.Button("Personalize")
    with gr.Row():
        train_acc = gr.Textbox(label="Train accuracy")
        test_acc = gr.Textbox(label="Test accuracy")
    annotated_samples = gr.Dataframe(headers=['image_name', 'label', 'image_embed'], row_count=(1, 'dynamic'), 
                                     col_count=(3, 'fixed'), label='Annotations', wrap=False)
    
    
    # HD stuff for incremental updates
    rng = gr.State([None])
    exemplars_state = gr.State([None, None])
    exemplars_state_integer = gr.State([None, None])  # declared but not used by any handler
    lut_state = gr.State([None, None])

    upvote.click(update_table_up, inputs=[image_display, image_fname, annotated_samples, images, exemplars_state, lut_state], outputs=[image_display, image_fname, annotated_samples, images, preds])
    downvote.click(update_table_down, inputs=[image_display, image_fname, annotated_samples, images, exemplars_state, lut_state], outputs=[image_display, image_fname, annotated_samples, images, preds])
    personalize.click(update_exemplars, [annotated_samples, rng, exemplars_state, lut_state], [rng, exemplars_state, train_acc, test_acc])
    demo.load(load_fn, inputs=[images, rng, exemplars_state, lut_state], outputs=[image_display, image_fname, rng, exemplars_state, lut_state])
    
    
demo.launch(show_error=True, debug=True)