# -*- coding: utf-8 -*- """message_bottle.ipynb Automatically generated by Colab. Original file is located at https://colab.research.google.com/drive/1I47sLakpuwERGzn-XoNct67mwiDS1mQD """ import torch import torch.nn as nn import torch.nn.functional as F torch.set_float32_matmul_precision('high') from tqdm import tqdm from transformers import AutoTokenizer, AutoModelForCausalLM class BottleneckT5Autoencoder: def __init__(self, model_path: str, device='cuda'): self.device = device self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512, torch_dtype=torch.bfloat16) self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(self.device) self.model.eval() # self.model = torch.compile(self.model) def embed(self, text: str) -> torch.FloatTensor: inputs = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device) decoder_inputs = self.tokenizer('', return_tensors='pt').to(self.device) return self.model( **inputs, decoder_input_ids=decoder_inputs['input_ids'], encode_only=True, ) def generate_from_latent(self, latent: torch.FloatTensor, max_length=512, temperature=1., top_p=.8, length_penalty=10, min_new_tokens=30) -> str: dummy_text = '.' dummy = self.embed(dummy_text) perturb_vector = latent - dummy self.model.perturb_vector = perturb_vector input_ids = self.tokenizer(dummy_text, return_tensors='pt').to(self.device).input_ids output = self.model.generate( input_ids=input_ids, max_length=max_length, do_sample=True, temperature=temperature, top_p=top_p, num_return_sequences=1, length_penalty=length_penalty, min_new_tokens=min_new_tokens, # num_beams=8, ) return self.tokenizer.decode(output[0], skip_special_tokens=True) autoencoder = BottleneckT5Autoencoder(model_path='thesephist/contra-bottleneck-t5-xl-wikipedia') import gradio as gr import numpy as np from sklearn.svm import SVC from sklearn.inspection import permutation_importance from sklearn import preprocessing import pandas as pd import random import time dtype = torch.bfloat16 torch.set_grad_enabled(False) prompt_list = [p for p in list(set( pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str] start_time = time.time() ####################### Setup Model # TODO put back # @spaces.GPU() def generate(prompt, in_embs=None,): if prompt != '': print(prompt) in_embs = in_embs / in_embs.abs().max() * .15 if in_embs != None else None in_embs = .9 * in_embs.to('cuda') + .5 * autoencoder.embed(prompt).to('cuda') if in_embs != None else autoencoder.embed(prompt).to('cuda') else: print('From embeds.') in_embs = in_embs / in_embs.abs().max() * .15 text = autoencoder.generate_from_latent(in_embs.to('cuda'), temperature=.3, top_p=.99, min_new_tokens=5) in_embs = autoencoder.embed(prompt) return text, in_embs.to('cpu') ####################### # TODO add to state instead of shared across all glob_idx = 0 def next_one(embs, ys, calibrate_prompts): global glob_idx glob_idx = glob_idx + 1 with torch.no_grad(): if len(calibrate_prompts) > 0: print('######### Calibrating with sample prompts #########') prompt = calibrate_prompts.pop(0) print(prompt) text, img_embs = generate(prompt) embs += img_embs print(len(embs)) return text, embs, ys, calibrate_prompts else: print('######### Roaming #########') # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike' if len(list(set(ys))) <= 1: embs.append(.01*torch.randn(2048)) embs.append(.01*torch.randn(2048)) ys.append(0) ys.append(1) if len(list(ys)) < 10: embs += [.01*torch.randn(2048)] * 3 ys += [0] * 3 pos_indices = [i for i in range(len(embs)) if ys[i] == 1] neg_indices = [i for i in range(len(embs)) if ys[i] == 0] # the embs & ys stay tied by index but we shuffle to drop randomly random.shuffle(pos_indices) random.shuffle(neg_indices) #if len(pos_indices) - len(neg_indices) > 48 and len(pos_indices) > 80: # pos_indices = pos_indices[32:] if len(neg_indices) - len(pos_indices) > 48/16 and len(pos_indices) > 6: pos_indices = pos_indices[5:] if len(neg_indices) - len(pos_indices) > 48/16 and len(neg_indices) > 6: neg_indices = neg_indices[5:] if len(neg_indices) > 25: neg_indices = neg_indices[1:] print(len(pos_indices), len(neg_indices)) indices = pos_indices + neg_indices embs = [embs[i] for i in indices] ys = [ys[i] for i in indices] indices = list(range(len(embs))) # also add the latest 0 and the latest 1 has_0 = False has_1 = False for i in reversed(range(len(ys))): if ys[i] == 0 and has_0 == False: indices.append(i) has_0 = True elif ys[i] == 1 and has_1 == False: indices.append(i) has_1 = True if has_0 and has_1: break # we may have just encountered a rare multi-threading diffusers issue (https://github.com/huggingface/diffusers/issues/5749); # this ends up adding a rating but losing an embedding, it seems. # let's take off a rating if so to continue without indexing errors. if len(ys) > len(embs): print('ys are longer than embs; popping latest rating') ys.pop(-1) feature_embs = np.array(torch.stack([embs[i].to('cpu') for i in indices]).to('cpu')) scaler = preprocessing.StandardScaler().fit(feature_embs) feature_embs = scaler.transform(feature_embs) chosen_y = np.array([ys[i] for i in indices]) print('Gathering coefficients') lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(feature_embs, chosen_y) coef_ = torch.tensor(lin_class.coef_, dtype=torch.double) coef_ = coef_ / coef_.abs().max() * 3 print(coef_.shape, 'COEF') print('Gathered') rng_prompt = random.choice(prompt_list) w = 1# if len(embs) % 2 == 0 else 0 im_emb = w * coef_.to(dtype=dtype) prompt= '' if glob_idx % 3 != 0 else rng_prompt text, im_emb = generate(prompt, im_emb) embs += im_emb return text, embs, ys, calibrate_prompts def start(_, embs, ys, calibrate_prompts): text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts) return [ gr.Button(value='Like (L)', interactive=True), gr.Button(value='Neither (Space)', interactive=True), gr.Button(value='Dislike (A)', interactive=True), gr.Button(value='Start', interactive=False), text, embs, ys, calibrate_prompts ] def choose(text, choice, embs, ys, calibrate_prompts): if choice == 'Like (L)': choice = 1 elif choice == 'Neither (Space)': embs = embs[:-1] text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts) return text, embs, ys, calibrate_prompts else: choice = 0 # if we detected NSFW, leave that area of latent space regardless of how they rated chosen. # TODO skip allowing rating if text == None: print('NSFW -- choice is disliked') choice = 0 ys += [choice]*1 text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts) return text, embs, ys, calibrate_prompts css = '''.gradio-container{max-width: 700px !important} #description{text-align: center} #description h1, #description h3{display: block} #description p{margin-top: 0} .fade-in-out {animation: fadeInOut 3s forwards} @keyframes fadeInOut { 0% { background: var(--bg-color); } 100% { background: var(--button-secondary-background-fill); } } ''' js_head = ''' ''' with gr.Blocks(css=css, head=js_head) as demo: gr.Markdown('''# Compass ### Generative Recommenders for Exporation of Text Explore the latent space without prompting based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/). ''', elem_id="description") embs = gr.State([]) ys = gr.State([]) calibrate_prompts = gr.State([ 'the moon is melting into my glass of tea', 'a sea slug -- pair of claws scuttling -- jelly fish glowing', 'an adorable creature. It may be a goblin or a pig or a slug.', 'an animation about a gorgeous nebula', 'a sketch of an impressive mountain by da vinci', 'a watercolor painting: the octopus writhes', ]) def l(): return None with gr.Row(elem_id='output-image'): text = gr.Textbox(interactive=False, elem_id="text") with gr.Row(equal_height=True): b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike") b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither") b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like") b1.click( choose, [text, b1, embs, ys, calibrate_prompts], [text, embs, ys, calibrate_prompts] ) b2.click( choose, [text, b2, embs, ys, calibrate_prompts], [text, embs, ys, calibrate_prompts] ) b3.click( choose, [text, b3, embs, ys, calibrate_prompts], [text, embs, ys, calibrate_prompts] ) with gr.Row(): b4 = gr.Button(value='Start') b4.click(start, [b4, embs, ys, calibrate_prompts], [b1, b2, b3, b4, text, embs, ys, calibrate_prompts]) with gr.Row(): html = gr.HTML('''