Spaces:
Sleeping
Sleeping
from turtle import title | |
import gradio as gr | |
from transformers import pipeline | |
import numpy as np | |
from PIL import Image | |
from transformers import CLIPProcessor, CLIPModel | |
import pandas as pd | |
from glob import glob | |
import random | |
from datetime import datetime | |
import numpy as np | |
from numpy.random import MT19937 | |
from numpy.random import RandomState, SeedSequence | |
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") | |
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") | |
HYPERDIMS = 1024 | |
VALUE_BITS = 8 | |
POS_BITS = 9 # CLIP features are 512 dims | |
val_bins = np.linspace(start=-1., stop=1., num=2**VALUE_BITS) | |
print(val_bins.shape, val_bins.min(), val_bins.max(), 'val bins') | |
def extract_features(image): | |
PIL_image = Image.fromarray(np.uint8(image)).convert('RGB') | |
inputs = clip_processor(text=["a photo of a cat", "a photo of a dog"], images=PIL_image, return_tensors="pt", padding=True) | |
outputs = clip_model(**inputs) | |
# print(outputs.image_embeds.shape) | |
return outputs.image_embeds | |
def update_table(img, img_name, df, state, label, exemplars_state, lut_state): | |
img_embeds = extract_features(img).detach().numpy().squeeze().tolist() | |
print(img_name, img.shape, len(img_embeds), 'images left:', len(state)) | |
new_df = pd.DataFrame({'image_name': img_name, 'label': label, 'image_embed': None}, columns=['image_name', 'image_embed', 'label'], index=[0]) | |
# print(new_df) | |
new_df.at[0, 'image_embed'] = img_embeds | |
df = pd.concat([df, new_df]) | |
filt = df["image_name"] != "" | |
df = df[filt] | |
state.pop() | |
t = state[-10:] | |
random.shuffle(t) | |
state = state[:-10] + t | |
idx = -1 | |
next_img = state[idx] | |
preds = predict(extract_features(img).detach().numpy(), exemplars_state, lut_state) | |
return next_img, next_img, df, state, preds | |
def update_table_up(img, img_name, df, state, exemplars_state, lut_state): | |
return update_table(img, img_name, df, state, 1, exemplars_state, lut_state) | |
def update_table_down(img, img_name, df, state, exemplars_state, lut_state): | |
return update_table(img, img_name, df, state, 0, exemplars_state, lut_state) | |
def make_LUT(nvalues, dims, rs): | |
lut = np.zeros(shape=(nvalues, dims)) | |
lut[0, :] = rs.binomial(n=1, p=0.5, size=(dims)) | |
for row in range(1, nvalues): | |
lut[row, :] = lut[row-1, :] | |
# flip few randomly | |
rand_idx = rs.choice(dims, size=dims//nvalues, replace=False) | |
lut[row, rand_idx] = 1 - lut[row, rand_idx] | |
assert np.abs(lut[row, :] - lut[row-1, :]).sum() ==dims//nvalues | |
unique_rows = np.unique(lut, axis=0) | |
assert len(unique_rows) == len(lut) | |
return lut | |
def load_fn(images, rng_state, exemplars_state, lut_state): | |
rs = RandomState(MT19937(SeedSequence(123456789))) | |
rng_state[0] = rs | |
exemplars_state[0] = rs.binomial(n=1, p=0.5, size=HYPERDIMS) | |
exemplars_state[1] = rs.binomial(n=1, p=0.5, size=HYPERDIMS) | |
lut_state[0] = make_LUT(2**VALUE_BITS, HYPERDIMS, rs) | |
assert lut_state[0].shape[0] == val_bins.shape[0] | |
lut_state[1] = rs.binomial(n=1, p=0.5, size=(2**POS_BITS, HYPERDIMS)) | |
print(exemplars_state) | |
print(lut_state[0].shape, lut_state[1].shape) | |
return images[-1], images[-1], rng_state, exemplars_state, lut_state | |
def quantize_embeds(embeds): | |
assert np.all(embeds >= val_bins[0]) | |
assert np.all(embeds <= val_bins[-1]) | |
embeds_flat = embeds.flatten() | |
all_pairs_dist = np.abs(embeds_flat[:, np.newaxis] - val_bins[np.newaxis, :]) | |
closest_bin = np.argmin(all_pairs_dist, axis=-1) | |
quantized_embeds_flat = val_bins[closest_bin] | |
quantized_embeds = np.reshape(quantized_embeds_flat, embeds.shape) | |
closest_bin = np.reshape(closest_bin, embeds.shape) | |
print(closest_bin.shape, 'values are in bins', closest_bin.min(), 'to', closest_bin.max()) | |
print('abs quant error avg', np.abs(embeds - quantized_embeds).mean()) | |
return quantized_embeds, closest_bin | |
def update_exemplars(df, rng, exemplars, lut): | |
embeds = np.array(df['image_embed'].values.tolist()) # df[['image_embed']].to_numpy() | |
labels = np.array(df['label'].values.tolist(), 'int') | |
# print(labels, labels.shape) | |
assert np.all(np.unique(labels) == [0, 1]) | |
labels_zero_idx = (labels == 0).nonzero()[0] | |
labels_one_idx = (labels == 1).nonzero()[0] | |
print(labels_zero_idx.shape, " zeros and ", labels_one_idx.shape, " ones") | |
# 70-30 split | |
labels_zero_train_idx = rng[0].choice(labels_zero_idx, size=int(.7 * len(labels_zero_idx)), replace=False) | |
labels_one_train_idx = rng[0].choice(labels_one_idx, size=int(.7 * len(labels_one_idx)), replace=False) | |
embeds_train = np.concatenate([embeds[labels_zero_train_idx], embeds[labels_one_train_idx]], axis=0) | |
labels_train = np.concatenate([labels[labels_zero_train_idx], labels[labels_one_train_idx]], axis=0) | |
print('Training set ', embeds_train.shape, labels_train.shape) | |
print(np.sum(labels_train == 0), " zeros and ", np.sum(labels_train == 1).sum(), " ones") | |
labels_zero_test_idx = np.setdiff1d(labels_zero_idx, labels_zero_train_idx) | |
labels_one_test_idx = np.setdiff1d(labels_one_idx, labels_one_train_idx) | |
embeds_test = np.concatenate([embeds[labels_zero_test_idx], embeds[labels_one_test_idx]], axis=0) | |
labels_test = np.concatenate([labels[labels_zero_test_idx], labels[labels_one_test_idx]], axis=0) | |
print('Test set ', embeds_test.shape, labels_test.shape) | |
quantized_embeds, closest_bin = quantize_embeds(embeds_train) | |
# closest bin is nexample X 512 | |
# lut[0] is nvals X dims | |
# hd_embeds in nexample x 512 x dims | |
hd_embeds_per_pos = lut[0][closest_bin] | |
# bundle along pos dimension 512 | |
# lut[1] is 512 x dims | |
xor = lambda a,b: a*(1.-b) + b*(1.-a) | |
hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos) | |
hd_embeds = np.sum(hd_embeds, axis=1) / embeds_train.shape[-1] | |
hd_embeds[hd_embeds >= 0.5] = 1. | |
hd_embeds[hd_embeds < 0.5] = 0. | |
# hd_embeds_integer is nexample x dims | |
exemplars_integer = [None, None] | |
exemplars_integer[0] = np.sum(hd_embeds[labels_train == 0], axis=0) | |
exemplars_integer[1] = np.sum(hd_embeds[labels_train == 1], axis=0) | |
exemplars[0] = exemplars_integer[0] / np.sum(labels_train == 0) | |
exemplars[1] = exemplars_integer[1] / np.sum(labels_train == 1) | |
exemplars[0][exemplars[0] >= 0.5] = 1. | |
exemplars[0][exemplars[0] < 0.5] = 0. | |
exemplars[1][exemplars[1] >= 0.5] = 1. | |
exemplars[1][exemplars[1] < 0.5] = 0. | |
print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum()) | |
preds = np.zeros(hd_embeds.shape[0]) | |
dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1) | |
dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1) | |
preds[dist_to_ex1 < dist_to_ex0] = 1 | |
print(preds.shape, labels_train.shape, np.sum(preds == labels_train)) | |
train_acc = np.sum(preds == labels_train) / len(labels_train) | |
rng, test_acc = score(embeds_test, labels_test, rng, exemplars, lut) | |
return rng, exemplars, train_acc, test_acc | |
def score(embeds, labels, rng, exemplars, lut): | |
quantized_embeds, closest_bin = quantize_embeds(embeds) | |
# closest bin is nexample X 512 | |
# lut[0] is nvals X dims | |
# hd_embeds in nexample x 512 x dims | |
hd_embeds_per_pos = lut[0][closest_bin] | |
# bundle along pos dimension 512 | |
# lut[1] is 512 x dims | |
xor = lambda a,b: a*(1.-b) + b*(1.-a) | |
hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos) | |
hd_embeds = np.sum(hd_embeds, axis=1) / embeds.shape[-1] | |
hd_embeds[hd_embeds >= 0.5] = 1. | |
hd_embeds[hd_embeds < 0.5] = 0. | |
# hd_embeds_integer is nexample x dims | |
print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum()) | |
preds = np.zeros(hd_embeds.shape[0]) | |
dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1) | |
dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1) | |
preds[dist_to_ex1 < dist_to_ex0] = 1 | |
print(preds.shape, labels.shape, np.sum(preds == labels), len(labels)) | |
acc = np.sum(preds == labels) / len(labels) | |
return rng, acc | |
def predict(embeds, exemplars, lut): | |
quantized_embeds, closest_bin = quantize_embeds(embeds) | |
# closest bin is nexample X 512 | |
# lut[0] is nvals X dims | |
# hd_embeds in nexample x 512 x dims | |
hd_embeds_per_pos = lut[0][closest_bin] | |
# bundle along pos dimension 512 | |
# lut[1] is 512 x dims | |
xor = lambda a,b: a*(1.-b) + b*(1.-a) | |
hd_embeds = xor(lut[1][np.newaxis, ...], hd_embeds_per_pos) | |
hd_embeds = np.sum(hd_embeds, axis=1) / embeds.shape[-1] | |
hd_embeds[hd_embeds >= 0.5] = 1. | |
hd_embeds[hd_embeds < 0.5] = 0. | |
# hd_embeds_integer is nexample x dims | |
# print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum()) | |
dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1) | |
dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1) | |
print('dists', dist_to_ex0, dist_to_ex1) | |
odds = np.abs(dist_to_ex0 - dist_to_ex1).item() | |
if dist_to_ex1 < dist_to_ex0: | |
preds = np.array([1., odds]) | |
else: | |
preds = np.array([odds, 1.]) | |
print(preds) | |
# preds = np.array([-1. * dist_to_ex0, -1. * dist_to_ex1]) | |
preds = preds / preds.sum() | |
# print(preds.shape) | |
print(preds) | |
return {"๐": preds[1], "๐": preds[0]} | |
with gr.Blocks(title="End-User Personalization") as demo: | |
img_list = glob('images/**/*.jpg') | |
random.seed(datetime.now().timestamp()) | |
random.shuffle(img_list) | |
images = gr.State(img_list) | |
# start_button = gr.Button(label="Start") | |
with gr.Row(): | |
image_display = gr.Image() | |
with gr.Column(): | |
image_fname = gr.Textbox() | |
preds = gr.Label("Prediction") | |
# text_display = gr.Text() | |
with gr.Row(): | |
upvote = gr.Button("๐") | |
downvote = gr.Button("๐") | |
personalize = gr.Button("Personalize") | |
with gr.Row(): | |
train_acc = gr.Textbox(label="Train accuracy") | |
test_acc = gr.Textbox(label="Test accuracy") | |
annotated_samples = gr.Dataframe(headers=['image_name', 'label', 'image_embed'], row_count=(1, 'dynamic'), | |
col_count=(3, 'fixed'), label='Annotations', wrap=False) | |
# HD stuff for incremental updates | |
rng = gr.State([None]) | |
exemplars_state = gr.State([None, None]) | |
exemplars_state_integer = gr.State([None, None]) | |
lut_state = gr.State([None, None]) | |
upvote.click(update_table_up, inputs=[image_display, image_fname, annotated_samples, images, exemplars_state, lut_state], outputs=[image_display, image_fname, annotated_samples, images, preds]) | |
downvote.click(update_table_down, inputs=[image_display, image_fname, annotated_samples, images, exemplars_state, lut_state], outputs=[image_display, image_fname, annotated_samples, images, preds]) | |
personalize.click(update_exemplars, [annotated_samples, rng, exemplars_state, lut_state], [rng, exemplars_state, train_acc, test_acc]) | |
demo.load(load_fn, inputs=[images, rng, exemplars_state, lut_state], outputs=[image_display, image_fname, rng, exemplars_state, lut_state]) | |
demo.launch(show_error=True, debug=True) |