aswin-raghavan
bugfix typo
d378ca4
from turtle import title
import gradio as gr
from transformers import pipeline
import numpy as np
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import pandas as pd
from glob import glob
import random
from datetime import datetime
import numpy as np
from numpy.random import MT19937
from numpy.random import RandomState, SeedSequence
# Pretrained CLIP (ViT-B/32) and its matching preprocessor, loaded once at import time
# and shared by every call to extract_features.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)
# Width of every binary hypervector.
HYPERDIMS = 1024
# Each embedding component is quantized to one of 2**VALUE_BITS levels.
VALUE_BITS = 8
# 2**POS_BITS position codes; CLIP features are 512 dims, so 9 bits cover them.
POS_BITS = 9
# Quantization levels, evenly spaced over [-1, 1] with both endpoints included.
val_bins = np.linspace(start=-1.0, stop=1.0, num=1 << VALUE_BITS)
print(val_bins.shape, val_bins.min(), val_bins.max(), 'val bins')
def extract_features(image):
    """Return the L2-normalised CLIP image embedding for one image.

    Args:
        image: array-like image (e.g. what a ``gr.Image`` input delivers); it is
            cast to uint8 and converted to RGB before preprocessing.

    Returns:
        A torch tensor of shape (1, projection_dim) with unit L2 norm, equal to
        the ``image_embeds`` that ``CLIPModel.forward`` would produce.
    """
    import torch  # local import: only needed here for no_grad

    pil_image = Image.fromarray(np.uint8(image)).convert('RGB')
    # Only the vision tower is needed. Calling the full model forward required
    # dummy text prompts and ran the text tower for nothing.
    inputs = clip_processor(images=pil_image, return_tensors="pt")
    # Inference only: don't build an autograd graph (callers just .detach()).
    with torch.no_grad():
        image_embeds = clip_model.get_image_features(**inputs)
    # CLIPModel.forward normalises image_embeds to unit length; do the same so
    # every component stays within [-1, 1], as quantize_embeds asserts.
    return image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
def update_table(img, img_name, df, state, label, exemplars_state, lut_state):
    """Append an annotation for the current image and advance to the next one.

    Args:
        img: the displayed image as an array (passed to extract_features).
        img_name: path of the displayed image, stored in ``image_name``.
        df: annotations table; rows whose ``image_name`` is "" are dropped.
        state: list of image paths still to show. The current one is popped
            from the end, and the last 10 remaining entries are reshuffled.
        label: 1 for thumbs-up, 0 for thumbs-down.
        exemplars_state, lut_state: HD model state forwarded to ``predict``.

    Returns:
        ``(next_img, next_img, df, state, preds)``: the next path twice (for
        the image and filename components), the updated table and queue, and
        the ``predict`` output.
    """
    # Run CLIP once and reuse the result for both the stored annotation and the
    # prediction (it used to be computed twice per click).
    img_features = extract_features(img).detach().numpy()
    img_embeds = img_features.squeeze().tolist()
    print(img_name, img.shape, len(img_embeds), 'images left:', len(state))
    # A list-of-dicts row keeps the embedding list as a single cell.
    new_df = pd.DataFrame(
        [{'image_name': img_name, 'image_embed': img_embeds, 'label': label}],
        columns=['image_name', 'image_embed', 'label'],
    )
    df = pd.concat([df, new_df])
    # Drop placeholder rows with an empty image name.
    df = df[df["image_name"] != ""]
    state.pop()
    # Light shuffle: only the next 10 candidates are reordered.
    t = state[-10:]
    random.shuffle(t)
    state = state[:-10] + t
    next_img = state[-1]
    # NOTE(review): this predicts for the image just annotated, not next_img
    # (which is what is now displayed) - confirm this is intended.
    preds = predict(img_features, exemplars_state, lut_state)
    return next_img, next_img, df, state, preds
def update_table_up(img, img_name, df, state, exemplars_state, lut_state):
    """Thumbs-up handler: annotate the current image with label 1."""
    return update_table(img=img, img_name=img_name, df=df, state=state, label=1,
                        exemplars_state=exemplars_state, lut_state=lut_state)
def update_table_down(img, img_name, df, state, exemplars_state, lut_state):
    """Thumbs-down handler: annotate the current image with label 0."""
    return update_table(img=img, img_name=img_name, df=df, state=state, label=0,
                        exemplars_state=exemplars_state, lut_state=lut_state)
def make_LUT(nvalues, dims, rs):
    """Build a table of correlated binary "level" hypervectors.

    Row 0 is a random {0, 1} vector. Each later row copies its predecessor and
    flips ``dims // nvalues`` randomly chosen positions, so adjacent rows differ
    in exactly that many bits.

    Args:
        nvalues: number of rows (quantization levels).
        dims: hypervector width.
        rs: numpy ``RandomState`` that supplies every random draw.

    Returns:
        Float array of shape ``(nvalues, dims)`` holding 0./1. entries.

    Raises:
        AssertionError: if a step does not flip exactly ``dims // nvalues``
            bits or two rows end up identical.
    """
    n_flips = dims // nvalues
    table = np.zeros(shape=(nvalues, dims))
    table[0, :] = rs.binomial(n=1, p=0.5, size=(dims))
    for level in range(1, nvalues):
        previous = table[level - 1, :]
        flip_positions = rs.choice(dims, size=n_flips, replace=False)
        current = previous.copy()
        current[flip_positions] = 1 - current[flip_positions]
        table[level, :] = current
        assert np.abs(current - previous).sum() == n_flips
    # Every level must map to its own hypervector.
    assert len(np.unique(table, axis=0)) == len(table)
    return table
def load_fn(images, rng_state, exemplars_state, lut_state):
    """Initialise the HD state and show the first image.

    Seeds a fixed ``RandomState`` and, drawing from it in order, fills the
    mutable state lists in place. It sets random placeholder exemplars for
    classes 0 and 1, the level LUT (one row per entry of ``val_bins``), and a
    random binary position LUT with ``2**POS_BITS`` rows.

    Returns:
        ``(first_image, first_image, rng_state, exemplars_state, lut_state)``,
        where ``first_image`` is the last entry of ``images`` (the queue is
        consumed from the end).
    """
    generator = RandomState(MT19937(SeedSequence(123456789)))
    rng_state[0] = generator
    # Placeholder exemplars until update_exemplars overwrites them.
    for cls in (0, 1):
        exemplars_state[cls] = generator.binomial(n=1, p=0.5, size=HYPERDIMS)
    lut_state[0] = make_LUT(2**VALUE_BITS, HYPERDIMS, generator)
    # One level hypervector per quantization bin.
    assert lut_state[0].shape[0] == val_bins.shape[0]
    lut_state[1] = generator.binomial(n=1, p=0.5, size=(2**POS_BITS, HYPERDIMS))
    print(exemplars_state)
    print(lut_state[0].shape, lut_state[1].shape)
    first_image = images[-1]
    return first_image, first_image, rng_state, exemplars_state, lut_state
def quantize_embeds(embeds, bins=None):
    """Snap every embedding component to its nearest quantization level.

    Args:
        embeds: numpy array of any shape whose values lie within
            ``[bins[0], bins[-1]]``.
        bins: ascending 1-D array of at least two quantization levels.
            Defaults to the module-level ``val_bins``.

    Returns:
        ``(quantized_embeds, closest_bin)``, both shaped like ``embeds``: the
        snapped values and the integer index of the chosen level. Ties go to
        the lower level (same as ``np.argmin`` over all levels).

    Raises:
        AssertionError: if any value falls outside the level range.
    """
    if bins is None:
        bins = val_bins
    assert np.all(embeds >= bins[0])
    assert np.all(embeds <= bins[-1])
    embeds_flat = embeds.flatten()
    # Binary-search the bracketing pair of levels instead of materialising an
    # (n_values x n_bins) float64 distance matrix (2 KiB per component for 256
    # levels, i.e. ~1 MiB per 512-dim embedding).
    upper = np.clip(np.searchsorted(bins, embeds_flat, side='left'), 1, len(bins) - 1)
    lower = upper - 1
    # Strict '<' keeps ties on the lower level, matching argmin's first-min rule.
    take_upper = np.abs(embeds_flat - bins[upper]) < np.abs(embeds_flat - bins[lower])
    closest_bin = np.where(take_upper, upper, lower)
    quantized_embeds = np.reshape(bins[closest_bin], embeds.shape)
    closest_bin = np.reshape(closest_bin, embeds.shape)
    print(closest_bin.shape, 'values are in bins', closest_bin.min(), 'to', closest_bin.max())
    print('abs quant error avg', np.abs(embeds - quantized_embeds).mean())
    return quantized_embeds, closest_bin
def update_exemplars(df, rng, exemplars, lut):
    """Fit the two class exemplars from the annotations and report accuracy.

    Annotated rows are split 70/30 per class using the session RNG in
    ``rng[0]``. Training embeddings are encoded as binary hypervectors, and each
    class exemplar is the bitwise majority of that class's training vectors.

    Args:
        df: annotations table with ``image_embed`` (list of floats) and
            ``label`` (0 or 1) columns.
        rng: one-element list holding the session ``RandomState``.
        exemplars: two-element list, overwritten in place with the new class-0
            and class-1 exemplars.
        lut: ``[level_lut, position_lut]`` as set up by ``load_fn``.

    Returns:
        ``(rng, exemplars, train_acc, test_acc)``. ``test_acc`` is NaN when too
        few annotations remain to form a held-out set.

    Raises:
        AssertionError: unless both labels 0 and 1 are present.
    """
    embeds = np.array(df['image_embed'].values.tolist())
    labels = np.array(df['label'].values.tolist(), 'int')
    # Two exemplars need at least one example of each class.
    assert np.all(np.unique(labels) == [0, 1])
    labels_zero_idx = (labels == 0).nonzero()[0]
    labels_one_idx = (labels == 1).nonzero()[0]
    print(labels_zero_idx.shape, " zeros and ", labels_one_idx.shape, " ones")

    # 70-30 split per class. Keep at least one training example per class:
    # int(.7 * 1) == 0 would leave a class with no training data, making its
    # exemplar 0/0 = NaN (and an all-empty training set crashes quantization).
    labels_zero_train_idx = rng[0].choice(
        labels_zero_idx, size=max(1, int(.7 * len(labels_zero_idx))), replace=False)
    labels_one_train_idx = rng[0].choice(
        labels_one_idx, size=max(1, int(.7 * len(labels_one_idx))), replace=False)
    embeds_train = np.concatenate([embeds[labels_zero_train_idx], embeds[labels_one_train_idx]], axis=0)
    labels_train = np.concatenate([labels[labels_zero_train_idx], labels[labels_one_train_idx]], axis=0)
    print('Training set ', embeds_train.shape, labels_train.shape)
    print(np.sum(labels_train == 0), " zeros and ", np.sum(labels_train == 1), " ones")

    labels_zero_test_idx = np.setdiff1d(labels_zero_idx, labels_zero_train_idx)
    labels_one_test_idx = np.setdiff1d(labels_one_idx, labels_one_train_idx)
    embeds_test = np.concatenate([embeds[labels_zero_test_idx], embeds[labels_one_test_idx]], axis=0)
    labels_test = np.concatenate([labels[labels_zero_test_idx], labels[labels_one_test_idx]], axis=0)
    print('Test set ', embeds_test.shape, labels_test.shape)

    # Encode: quantize each feature, look up its level hypervector
    # (n_train x n_features x dims), XOR-bind it with its position hypervector,
    # then bundle across features by majority vote.
    _, closest_bin = quantize_embeds(embeds_train)
    hd_embeds_per_pos = lut[0][closest_bin]
    pos_hvs = lut[1][np.newaxis, ...]
    bound = pos_hvs * (1. - hd_embeds_per_pos) + hd_embeds_per_pos * (1. - pos_hvs)
    hd_embeds = np.sum(bound, axis=1) / embeds_train.shape[-1]
    hd_embeds = np.where(hd_embeds >= 0.5, 1., 0.)

    # Each exemplar is the bitwise majority of its class's training vectors.
    for cls in (0, 1):
        votes = np.sum(hd_embeds[labels_train == cls], axis=0) / np.sum(labels_train == cls)
        exemplars[cls] = np.where(votes >= 0.5, 1., 0.)
    print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum())

    # Training accuracy: class 1 iff strictly closer to exemplar 1.
    dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1)
    dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1)
    preds = (dist_to_ex1 < dist_to_ex0).astype(float)
    print(preds.shape, labels_train.shape, np.sum(preds == labels_train))
    train_acc = np.sum(preds == labels_train) / len(labels_train)

    if len(labels_test) == 0:
        # Too few annotations to hold any out; scoring an empty set would crash.
        print('Test set is empty; test accuracy unavailable')
        test_acc = float('nan')
    else:
        rng, test_acc = score(embeds_test, labels_test, rng, exemplars, lut)
    return rng, exemplars, train_acc, test_acc
def score(embeds, labels, rng, exemplars, lut):
    """Measure nearest-exemplar accuracy on a labelled set of embeddings.

    Each embedding is encoded as in training (quantize, look up level
    hypervectors, XOR-bind with position hypervectors, majority-bundle). It is
    predicted as class 1 iff its Hamming distance to ``exemplars[1]`` is
    strictly smaller than to ``exemplars[0]``.

    Args:
        embeds: (n_examples, n_features) array of embeddings.
        labels: (n_examples,) array of 0/1 labels.
        rng: passed through unchanged.
        exemplars: two binary hypervectors, class 0 then class 1.
        lut: ``[level_lut, position_lut]``.

    Returns:
        ``(rng, accuracy)``.
    """
    _, level_idx = quantize_embeds(embeds)
    level_hvs = lut[0][level_idx]
    position_hvs = lut[1][np.newaxis, ...]
    bound = position_hvs * (1. - level_hvs) + level_hvs * (1. - position_hvs)
    votes = bound.sum(axis=1) / embeds.shape[-1]
    encoded = np.where(votes >= 0.5, 1., 0.)
    ex0, ex1 = exemplars[0], exemplars[1]
    print(ex0.shape, ex1.shape, np.abs(ex0 - ex1).sum())
    dist_to_ex0 = np.abs(encoded - ex0[np.newaxis, ...]).sum(axis=-1)
    dist_to_ex1 = np.abs(encoded - ex1[np.newaxis, ...]).sum(axis=-1)
    preds = (dist_to_ex1 < dist_to_ex0).astype(float)
    n_correct = np.sum(preds == labels)
    print(preds.shape, labels.shape, n_correct, len(labels))
    return rng, n_correct / len(labels)
def predict(embeds, exemplars, lut):
    """Return display confidences for a single embedding.

    The embedding is encoded into a binary hypervector (quantize, level lookup,
    XOR-bind with position hypervectors, majority-bundle) and compared with the
    two class exemplars by Hamming distance. With ``margin = |d0 - d1|``, the
    closer class gets ``margin / (margin + 1)`` and the other class gets
    ``1 / (margin + 1)``. Equal distances give 0.5 each.

    Args:
        embeds: one embedding with shape (1, n_features); ``.item()`` below
            requires exactly one example.
        exemplars: two binary hypervectors, class 0 (thumbs-down) then class 1
            (thumbs-up).
        lut: ``[level_lut, position_lut]``.

    Returns:
        dict mapping the thumbs-up / thumbs-down emoji to confidences that sum to 1.
    """
    _, closest_bin = quantize_embeds(embeds)
    # (n_examples, n_features, dims) level hypervectors.
    hd_embeds_per_pos = lut[0][closest_bin]
    pos_hvs = lut[1][np.newaxis, ...]
    # XOR on {0, 1} binds each level vector to its position vector.
    bound = pos_hvs * (1. - hd_embeds_per_pos) + hd_embeds_per_pos * (1. - pos_hvs)
    hd_embeds = np.sum(bound, axis=1) / embeds.shape[-1]
    hd_embeds = np.where(hd_embeds >= 0.5, 1., 0.)
    dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1)
    dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1)
    print('dists', dist_to_ex0, dist_to_ex1)
    odds = np.abs(dist_to_ex0 - dist_to_ex1).item()
    if odds == 0:
        # Equidistant: no evidence either way. Previously this fell into the
        # else-branch as [0, 1], i.e. 100% thumbs-up.
        preds = np.array([1., 1.])
    elif dist_to_ex1 < dist_to_ex0:
        preds = np.array([1., odds])
    else:
        preds = np.array([odds, 1.])
    print(preds)
    preds = preds / preds.sum()
    print(preds)
    return {"👍": preds[1], "👎": preds[0]}
with gr.Blocks(title="End-User Personalization") as demo:
    # Image queue, shuffled once at startup with a time-based seed.
    # NOTE(review): without recursive=True, '**' matches exactly one directory
    # level (images/<dir>/*.jpg) - confirm that is the intended layout.
    img_list = glob('images/**/*.jpg')
    random.seed(datetime.now().timestamp())
    random.shuffle(img_list)
    images = gr.State(img_list)
    with gr.Row():
        image_display = gr.Image()
        with gr.Column():
            image_fname = gr.Textbox()
            # label= titles the component; a positional string would be shown
            # as the Label's initial value instead.
            preds = gr.Label(label="Prediction")
            with gr.Row():
                upvote = gr.Button("👍")
                downvote = gr.Button("👎")
                personalize = gr.Button("Personalize")
            with gr.Row():
                train_acc = gr.Textbox(label="Train accuracy")
                test_acc = gr.Textbox(label="Test accuracy")
    annotated_samples = gr.Dataframe(headers=['image_name', 'label', 'image_embed'], row_count=(1, 'dynamic'),
                                     col_count=(3, 'fixed'), label='Annotations', wrap=False)
    # HD state for incremental updates; populated by load_fn.
    rng = gr.State([None])
    exemplars_state = gr.State([None, None])
    exemplars_state_integer = gr.State([None, None])  # not used by any handler below
    lut_state = gr.State([None, None])
    upvote.click(update_table_up, inputs=[image_display, image_fname, annotated_samples, images, exemplars_state, lut_state], outputs=[image_display, image_fname, annotated_samples, images, preds])
    downvote.click(update_table_down, inputs=[image_display, image_fname, annotated_samples, images, exemplars_state, lut_state], outputs=[image_display, image_fname, annotated_samples, images, preds])
    personalize.click(update_exemplars, [annotated_samples, rng, exemplars_state, lut_state], [rng, exemplars_state, train_acc, test_acc])
    demo.load(load_fn, inputs=[images, rng, exemplars_state, lut_state], outputs=[image_display, image_fname, rng, exemplars_state, lut_state])

demo.launch(show_error=True, debug=True)