"""Gradio demo: personalize an image up/down-vote classifier.

Images are embedded with CLIP, the embeddings are encoded as binary
hypervectors (hyperdimensional computing), and classification is done by
Hamming distance to one bundled exemplar hypervector per class.
"""

import random
from datetime import datetime
from glob import glob

import gradio as gr
import numpy as np
import pandas as pd
from numpy.random import MT19937, RandomState, SeedSequence
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

# NOTE(review): removed the accidental `from turtle import title` — importing
# turtle requires tkinter/a display and crashes on headless servers; `title`
# was never used as a name (only as a keyword argument to gr.Blocks).

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

HYPERDIMS = 1024   # dimensionality of the binary hypervectors
VALUE_BITS = 8     # 2**8 quantization bins for embedding values
POS_BITS = 9       # CLIP features are 512 dims (2**9 position codes)

# Quantization bin centers, evenly spaced over the assumed [-1, 1] value range.
val_bins = np.linspace(start=-1., stop=1., num=2**VALUE_BITS)
print(val_bins.shape, val_bins.min(), val_bins.max(), 'val bins')


def extract_features(image):
    """Return the CLIP image embedding (1 x 512 torch tensor) for a numpy image."""
    PIL_image = Image.fromarray(np.uint8(image)).convert('RGB')
    # The text prompts are only supplied because the processor requires them;
    # just the image embedding is used downstream.
    inputs = clip_processor(text=["a photo of a cat", "a photo of a dog"],
                            images=PIL_image, return_tensors="pt", padding=True)
    outputs = clip_model(**inputs)
    return outputs.image_embeds


def _xor(a, b):
    """Element-wise XOR on {0, 1} float arrays, computed arithmetically."""
    return a * (1. - b) + b * (1. - a)


def _encode_hd(embeds, lut):
    """Encode real embeddings (nexample x 512) as binary hypervectors.

    Each value is quantized, mapped to its value hypervector (``lut[0]``,
    nvals x dims), bound (XOR) with its position hypervector (``lut[1]``,
    512 x dims), and the 512 positions are bundled by majority vote.

    Returns an (nexample x HYPERDIMS) array of 0./1. values.
    """
    quantized_embeds, closest_bin = quantize_embeds(embeds)
    # closest_bin is nexample x 512; indexing lut[0] gives nexample x 512 x dims.
    hd_embeds_per_pos = lut[0][closest_bin]
    # Bind each position's value code with the position code.
    hd_embeds = _xor(lut[1][np.newaxis, ...], hd_embeds_per_pos)
    # Bundle along the 512 positions: mean, then majority threshold at 0.5.
    hd_embeds = np.sum(hd_embeds, axis=1) / embeds.shape[-1]
    hd_embeds[hd_embeds >= 0.5] = 1.
    hd_embeds[hd_embeds < 0.5] = 0.
    return hd_embeds


def _exemplar_distances(hd_embeds, exemplars):
    """Hamming distances of each hypervector to the two class exemplars."""
    dist_to_ex0 = np.abs(hd_embeds - exemplars[0][np.newaxis, ...]).sum(axis=-1)
    dist_to_ex1 = np.abs(hd_embeds - exemplars[1][np.newaxis, ...]).sum(axis=-1)
    return dist_to_ex0, dist_to_ex1


def update_table(img, img_name, df, state, label, exemplars_state, lut_state):
    """Record the user's vote for the current image and advance to the next.

    Appends (image_name, CLIP embedding, label) to the annotation dataframe,
    consumes the current image from the remaining-images state, reshuffles the
    next ten candidates, and returns a live prediction for the voted image.
    """
    # Extract features once (the original ran the CLIP forward pass twice).
    feats = extract_features(img).detach().numpy()
    img_embeds = feats.squeeze().tolist()
    print(img_name, img.shape, len(img_embeds), 'images left:', len(state))
    new_df = pd.DataFrame({'image_name': img_name, 'label': label, 'image_embed': None},
                          columns=['image_name', 'image_embed', 'label'], index=[0])
    # A list cell cannot be passed through the constructor dict; set it afterwards.
    new_df.at[0, 'image_embed'] = img_embeds
    df = pd.concat([df, new_df])
    # Drop placeholder rows the Dataframe widget may contain.
    filt = df["image_name"] != ""
    df = df[filt]
    state.pop()  # current image has been consumed
    # Shuffle the last ten remaining images so the upcoming order is less predictable.
    t = state[-10:]
    random.shuffle(t)
    state = state[:-10] + t
    next_img = state[-1]
    preds = predict(feats, exemplars_state, lut_state)
    return next_img, next_img, df, state, preds


def update_table_up(img, img_name, df, state, exemplars_state, lut_state):
    """Upvote handler: record label 1."""
    return update_table(img, img_name, df, state, 1, exemplars_state, lut_state)


def update_table_down(img, img_name, df, state, exemplars_state, lut_state):
    """Downvote handler: record label 0."""
    return update_table(img, img_name, df, state, 0, exemplars_state, lut_state)


def make_LUT(nvalues, dims, rs):
    """Build a lookup table of ``nvalues`` binary hypervectors of size ``dims``.

    Consecutive rows differ in exactly ``dims // nvalues`` randomly chosen
    bits, so nearby quantization bins map to codes with small Hamming
    distance. All rows are asserted to be unique.
    """
    lut = np.zeros(shape=(nvalues, dims))
    lut[0, :] = rs.binomial(n=1, p=0.5, size=(dims))
    for row in range(1, nvalues):
        lut[row, :] = lut[row - 1, :]
        # Flip a few randomly chosen bits relative to the previous row.
        rand_idx = rs.choice(dims, size=dims // nvalues, replace=False)
        lut[row, rand_idx] = 1 - lut[row, rand_idx]
        assert np.abs(lut[row, :] - lut[row - 1, :]).sum() == dims // nvalues
    unique_rows = np.unique(lut, axis=0)
    assert len(unique_rows) == len(lut)
    return lut


def load_fn(images, rng_state, exemplars_state, lut_state):
    """Initialize RNG, exemplars and LUTs on page load; show the first image."""
    # Fixed seed so the hypervector codebooks are reproducible across sessions.
    rs = RandomState(MT19937(SeedSequence(123456789)))
    rng_state[0] = rs
    # Random exemplars so predict() works before any training has happened.
    exemplars_state[0] = rs.binomial(n=1, p=0.5, size=HYPERDIMS)
    exemplars_state[1] = rs.binomial(n=1, p=0.5, size=HYPERDIMS)
    lut_state[0] = make_LUT(2**VALUE_BITS, HYPERDIMS, rs)
    assert lut_state[0].shape[0] == val_bins.shape[0]
    lut_state[1] = rs.binomial(n=1, p=0.5, size=(2**POS_BITS, HYPERDIMS))
    print(exemplars_state)
    print(lut_state[0].shape, lut_state[1].shape)
    return images[-1], images[-1], rng_state, exemplars_state, lut_state


def quantize_embeds(embeds):
    """Snap each embedding value to the nearest bin center in ``val_bins``.

    Returns ``(quantized_embeds, closest_bin)``, both shaped like ``embeds``.
    """
    assert np.all(embeds >= val_bins[0])
    assert np.all(embeds <= val_bins[-1])
    embeds_flat = embeds.flatten()
    # All-pairs |value - bin| distances, then the nearest bin per value.
    all_pairs_dist = np.abs(embeds_flat[:, np.newaxis] - val_bins[np.newaxis, :])
    closest_bin = np.argmin(all_pairs_dist, axis=-1)
    quantized_embeds_flat = val_bins[closest_bin]
    quantized_embeds = np.reshape(quantized_embeds_flat, embeds.shape)
    closest_bin = np.reshape(closest_bin, embeds.shape)
    print(closest_bin.shape, 'values are in bins', closest_bin.min(), 'to', closest_bin.max())
    print('abs quant error avg', np.abs(embeds - quantized_embeds).mean())
    return quantized_embeds, closest_bin


def update_exemplars(df, rng, exemplars, lut):
    """Train the two class exemplars from the annotation dataframe.

    Makes a stratified 70/30 train/test split, HD-encodes the training
    embeddings, bundles them per class into binary exemplars, and returns
    ``(rng, exemplars, train_acc, test_acc)``.
    """
    embeds = np.array(df['image_embed'].values.tolist())
    labels = np.array(df['label'].values.tolist(), 'int')
    # Requires at least one vote of each class before personalizing.
    assert np.all(np.unique(labels) == [0, 1])
    labels_zero_idx = (labels == 0).nonzero()[0]
    labels_one_idx = (labels == 1).nonzero()[0]
    print(labels_zero_idx.shape, " zeros and ", labels_one_idx.shape, " ones")
    # Stratified 70-30 split.
    labels_zero_train_idx = rng[0].choice(labels_zero_idx, size=int(.7 * len(labels_zero_idx)), replace=False)
    labels_one_train_idx = rng[0].choice(labels_one_idx, size=int(.7 * len(labels_one_idx)), replace=False)
    embeds_train = np.concatenate([embeds[labels_zero_train_idx], embeds[labels_one_train_idx]], axis=0)
    labels_train = np.concatenate([labels[labels_zero_train_idx], labels[labels_one_train_idx]], axis=0)
    print('Training set ', embeds_train.shape, labels_train.shape)
    print(np.sum(labels_train == 0), " zeros and ", np.sum(labels_train == 1), " ones")
    labels_zero_test_idx = np.setdiff1d(labels_zero_idx, labels_zero_train_idx)
    labels_one_test_idx = np.setdiff1d(labels_one_idx, labels_one_train_idx)
    embeds_test = np.concatenate([embeds[labels_zero_test_idx], embeds[labels_one_test_idx]], axis=0)
    labels_test = np.concatenate([labels[labels_zero_test_idx], labels[labels_one_test_idx]], axis=0)
    print('Test set ', embeds_test.shape, labels_test.shape)
    hd_embeds = _encode_hd(embeds_train, lut)
    # Bundle (majority vote) the training hypervectors per class.
    exemplars[0] = np.sum(hd_embeds[labels_train == 0], axis=0) / np.sum(labels_train == 0)
    exemplars[1] = np.sum(hd_embeds[labels_train == 1], axis=0) / np.sum(labels_train == 1)
    for c in (0, 1):
        exemplars[c][exemplars[c] >= 0.5] = 1.
        exemplars[c][exemplars[c] < 0.5] = 0.
    print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum())
    dist_to_ex0, dist_to_ex1 = _exemplar_distances(hd_embeds, exemplars)
    preds = np.zeros(hd_embeds.shape[0])
    preds[dist_to_ex1 < dist_to_ex0] = 1
    print(preds.shape, labels_train.shape, np.sum(preds == labels_train))
    train_acc = np.sum(preds == labels_train) / len(labels_train)
    rng, test_acc = score(embeds_test, labels_test, rng, exemplars, lut)
    return rng, exemplars, train_acc, test_acc


def score(embeds, labels, rng, exemplars, lut):
    """Return ``(rng, accuracy)`` of nearest-exemplar classification on a set."""
    hd_embeds = _encode_hd(embeds, lut)
    print(exemplars[0].shape, exemplars[1].shape, np.abs(exemplars[0] - exemplars[1]).sum())
    dist_to_ex0, dist_to_ex1 = _exemplar_distances(hd_embeds, exemplars)
    preds = np.zeros(hd_embeds.shape[0])
    preds[dist_to_ex1 < dist_to_ex0] = 1
    print(preds.shape, labels.shape, np.sum(preds == labels), len(labels))
    acc = np.sum(preds == labels) / len(labels)
    return rng, acc


def predict(embeds, exemplars, lut):
    """Classify one embedding; return a label→confidence dict for gr.Label."""
    hd_embeds = _encode_hd(embeds, lut)
    dist_to_ex0, dist_to_ex1 = _exemplar_distances(hd_embeds, exemplars)
    print('dists', dist_to_ex0, dist_to_ex1)
    # `odds` is the raw margin between the two distances.
    odds = np.abs(dist_to_ex0 - dist_to_ex1).item()
    # NOTE(review): the winning class gets the margin and the losing class gets
    # 1.0 before normalization — preserved from the original; confirm intended.
    if dist_to_ex1 < dist_to_ex0:
        preds = np.array([1., odds])
    else:
        preds = np.array([odds, 1.])
    print(preds)
    preds = preds / preds.sum()
    print(preds)
    return {"👍": preds[1], "👎": preds[0]}


with gr.Blocks(title="End-User Personalization") as demo:
    img_list = glob('images/**/*.jpg')
    random.seed(datetime.now().timestamp())
    random.shuffle(img_list)
    images = gr.State(img_list)

    with gr.Row():
        image_display = gr.Image()
        with gr.Column():
            image_fname = gr.Textbox()
            preds = gr.Label("Prediction")
    with gr.Row():
        upvote = gr.Button("👍")
        downvote = gr.Button("👎")
        personalize = gr.Button("Personalize")
    with gr.Row():
        train_acc = gr.Textbox(label="Train accuracy")
        test_acc = gr.Textbox(label="Test accuracy")
    annotated_samples = gr.Dataframe(headers=['image_name', 'label', 'image_embed'],
                                     row_count=(1, 'dynamic'), col_count=(3, 'fixed'),
                                     label='Annotations', wrap=False)

    # HD state for incremental updates.
    rng = gr.State([None])
    exemplars_state = gr.State([None, None])
    exemplars_state_integer = gr.State([None, None])  # NOTE(review): currently unused
    lut_state = gr.State([None, None])

    upvote.click(update_table_up,
                 inputs=[image_display, image_fname, annotated_samples, images,
                         exemplars_state, lut_state],
                 outputs=[image_display, image_fname, annotated_samples, images, preds])
    downvote.click(update_table_down,
                   inputs=[image_display, image_fname, annotated_samples, images,
                           exemplars_state, lut_state],
                   outputs=[image_display, image_fname, annotated_samples, images, preds])
    personalize.click(update_exemplars,
                      [annotated_samples, rng, exemplars_state, lut_state],
                      [rng, exemplars_state, train_acc, test_acc])
    demo.load(load_fn,
              inputs=[images, rng, exemplars_state, lut_state],
              outputs=[image_display, image_fname, rng, exemplars_state, lut_state])

demo.launch(show_error=True, debug=True)