# -*- coding: utf-8 -*-
"""message_bottle.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1I47sLakpuwERGzn-XoNct67mwiDS1mQD
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_float32_matmul_precision('high')

from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM


class BottleneckT5Autoencoder:
    """Text -> fixed-size bottleneck latent -> text, via a contra bottleneck T5."""

    def __init__(self, model_path: str, device='cuda'):
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, model_max_length=512)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(self.device)
        self.model.eval()
        # self.model = torch.compile(self.model)

    def embed(self, text: str) -> torch.FloatTensor:
        # encode_only is handled by the model's remote code: it returns the
        # bottleneck latent for `text` instead of running the decoder.
        inputs = self.tokenizer(text, return_tensors='pt', padding=True).to(self.device)
        decoder_inputs = self.tokenizer('', return_tensors='pt').to(self.device)
        return self.model(
            **inputs,
            decoder_input_ids=decoder_inputs['input_ids'],
            encode_only=True,
        )

    def generate_from_latent(self, latent: torch.FloatTensor, max_length=512, temperature=1.,
                             top_p=.8, length_penalty=10, min_new_tokens=30) -> str:
        # Decode by generating from a dummy '.' prompt while the model's remote
        # code shifts the encoder latent by perturb_vector (target minus dummy).
        dummy_text = '.'
        dummy = self.embed(dummy_text)
        perturb_vector = latent - dummy
        self.model.perturb_vector = perturb_vector
        input_ids = self.tokenizer(dummy_text, return_tensors='pt').to(self.device).input_ids
        output = self.model.generate(
            input_ids=input_ids,
            max_length=max_length,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            num_return_sequences=1,
            length_penalty=length_penalty,
            min_new_tokens=min_new_tokens,
            # num_beams=8,
        )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)


autoencoder = BottleneckT5Autoencoder(model_path='thesephist/contra-bottleneck-t5-xl-wikipedia')
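
# A minimal, optional sketch of the latent arithmetic the class above enables.
# It assumes embed() returns a single latent row per input (the 2048-dim
# placeholder vectors used later in this script suggest d_model = 2048); the
# blend weights are illustrative, not values the app uses. Defining the
# function does not run it -- call it manually to try it out.
def _latent_interpolation_demo():
    a = autoencoder.embed('the moon is melting into my glass of tea')
    b = autoencoder.embed('a watercolor painting: the octopus writhes')
    # Any blend of latents is a valid decoder input, since generate_from_latent
    # only needs the offset of the target latent from the '.' dummy latent.
    midpoint = .5 * a + .5 * b
    print(autoencoder.generate_from_latent(midpoint, temperature=.3, top_p=.99))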

import gradio as gr
import numpy as np
from sklearn.svm import SVC
from sklearn.inspection import permutation_importance
from sklearn import preprocessing
import pandas as pd
import random
import time

dtype = torch.bfloat16
torch.set_grad_enabled(False)

prompt_list = [p for p in list(set(
    pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if isinstance(p, str)]

start_time = time.time()

####################### Setup Model

# TODO put back
# @spaces.GPU()
def generate(prompt, in_embs=None):
    if prompt != '':
        print(prompt)
        # blend the prompt's latent with the incoming latent, if any
        in_embs = in_embs / in_embs.abs().max() * .15 if in_embs is not None else None
        in_embs = .9 * in_embs.to('cuda') + .5 * autoencoder.embed(prompt).to('cuda') if in_embs is not None else autoencoder.embed(prompt).to('cuda')
    else:
        print('From embeds.')
        in_embs = in_embs / in_embs.abs().max() * .15
    text = autoencoder.generate_from_latent(in_embs.to('cuda'), temperature=.3, top_p=.99, min_new_tokens=5)
    # note: this re-embeds the prompt rather than the generated text
    in_embs = autoencoder.embed(prompt)
    return text, in_embs.to('cpu')

#######################

# TODO add to state instead of shared across all
glob_idx = 0

def next_one(embs, ys, calibrate_prompts):
    global glob_idx
    glob_idx = glob_idx + 1
    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            print(prompt)
            text, img_embs = generate(prompt)
            embs += img_embs
            print(len(embs))
            return text, embs, ys, calibrate_prompts
        else:
            print('######### Roaming #########')

            # handle the case where every rating so far is the same ('Neither',
            # all 'Like', or all 'Dislike'): pad with random examples of both classes
            if len(list(set(ys))) <= 1:
                embs.append(.01 * torch.randn(2048))
                embs.append(.01 * torch.randn(2048))
                ys.append(0)
                ys.append(1)
            if len(list(ys)) < 10:
                embs += [.01 * torch.randn(2048)] * 3
                ys += [0] * 3

            pos_indices = [i for i in range(len(embs)) if ys[i] == 1]
            neg_indices = [i for i in range(len(embs)) if ys[i] == 0]

            # the embs & ys stay tied by index, but we shuffle so drops are random
            random.shuffle(pos_indices)
            random.shuffle(neg_indices)

            # if len(pos_indices) - len(neg_indices) > 48 and len(pos_indices) > 80:
            #     pos_indices = pos_indices[32:]
            if len(neg_indices) - len(pos_indices) > 48 / 16 and len(pos_indices) > 6:
                pos_indices = pos_indices[5:]
            if len(neg_indices) - len(pos_indices) > 48 / 16 and len(neg_indices) > 6:
                neg_indices = neg_indices[5:]
            if len(neg_indices) > 25:
                neg_indices = neg_indices[1:]

            print(len(pos_indices), len(neg_indices))
            indices = pos_indices + neg_indices
            embs = [embs[i] for i in indices]
            ys = [ys[i] for i in indices]
            indices = list(range(len(embs)))

            # also add the latest 0 and the latest 1
            has_0 = False
            has_1 = False
            for i in reversed(range(len(ys))):
                if ys[i] == 0 and not has_0:
                    indices.append(i)
                    has_0 = True
                elif ys[i] == 1 and not has_1:
                    indices.append(i)
                    has_1 = True
                if has_0 and has_1:
                    break

            # we may have just encountered a rare multi-threading diffusers issue
            # (https://github.com/huggingface/diffusers/issues/5749); it seems to
            # add a rating while losing an embedding. Take off a rating if so,
            # to continue without indexing errors.
            if len(ys) > len(embs):
                print('ys are longer than embs; popping latest rating')
                ys.pop(-1)

            feature_embs = np.array(torch.stack([embs[i].to('cpu') for i in indices]).to('cpu'))
            scaler = preprocessing.StandardScaler().fit(feature_embs)
            feature_embs = scaler.transform(feature_embs)
            chosen_y = np.array([ys[i] for i in indices])

            print('Gathering coefficients')
            lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(feature_embs, chosen_y)
            coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
            coef_ = coef_ / coef_.abs().max() * 3
            print(coef_.shape, 'COEF')
            print('Gathered')

            rng_prompt = random.choice(prompt_list)
            w = 1  # if len(embs) % 2 == 0 else 0
            im_emb = w * coef_.to(dtype=dtype)

            prompt = '' if glob_idx % 3 != 0 else rng_prompt
            text, im_emb = generate(prompt, im_emb)
            embs += im_emb
            return text, embs, ys, calibrate_prompts
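
# For reference: the steering vector computed in next_one is just the (rescaled)
# normal vector of a linear SVM separating liked from disliked latents. A
# self-contained toy version with random stand-in features -- the sizes here are
# illustrative, chosen to match the 2048-dim placeholders above:
def _svm_direction_demo():
    X = preprocessing.StandardScaler().fit_transform(np.random.randn(12, 2048))
    y = np.array([0, 1] * 6)
    clf = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(X, y)
    # coef_ has shape (1, n_features); rescaling so the largest entry has
    # magnitude 3 mirrors what next_one feeds back into generate()
    return clf.coef_ / np.abs(clf.coef_).max() * 3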

def start(_, embs, ys, calibrate_prompts):
    text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
    return [
        gr.Button(value='Like (L)', interactive=True),
        gr.Button(value='Neither (Space)', interactive=True),
        gr.Button(value='Dislike (A)', interactive=True),
        gr.Button(value='Start', interactive=False),
        text,
        embs,
        ys,
        calibrate_prompts,
    ]


def choose(text, choice, embs, ys, calibrate_prompts):
    if choice == 'Like (L)':
        choice = 1
    elif choice == 'Neither (Space)':
        embs = embs[:-1]
        text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
        return text, embs, ys, calibrate_prompts
    else:
        choice = 0

    # if we detected NSFW, leave that area of latent space regardless of how they rated it.
    # TODO skip allowing rating
    if text is None:
        print('NSFW -- choice is disliked')
        choice = 0

    ys += [choice]
    text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
    return text, embs, ys, calibrate_prompts


css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
    0% { background: var(--bg-color); }
    100% { background: var(--button-secondary-background-fill); }
}
'''

js_head = '''
'''
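
# js_head is empty in this export, but the '(L)' / '(Space)' / '(A)' button
# labels below imply keyboard bindings were attached at some point. A
# hypothetical replacement -- not from the original; the selectors reuse the
# elem_ids defined in the layout below -- could look like:
#
# js_head = '''<script>
# document.addEventListener('keydown', (e) => {
#     if (e.key === 'l') document.querySelector('#like').click();
#     if (e.key === ' ') document.querySelector('#neither').click();
#     if (e.key === 'a') document.querySelector('#dislike').click();
# });
# </script>'''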

with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''# Compass
### Generative Recommenders for Exploration of Text

Explore the latent space without prompting, based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")

    embs = gr.State([])
    ys = gr.State([])
    calibrate_prompts = gr.State([
        'the moon is melting into my glass of tea',
        'a sea slug -- pair of claws scuttling -- jelly fish glowing',
        'an adorable creature. It may be a goblin or a pig or a slug.',
        'an animation about a gorgeous nebula',
        'a sketch of an impressive mountain by da vinci',
        'a watercolor painting: the octopus writhes',
    ])

    with gr.Row(elem_id='output-image'):
        text = gr.Textbox(interactive=False, elem_id="text")

    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
        b1.click(
            choose,
            [text, b1, embs, ys, calibrate_prompts],
            [text, embs, ys, calibrate_prompts]
        )
        b2.click(
            choose,
            [text, b2, embs, ys, calibrate_prompts],
            [text, embs, ys, calibrate_prompts]
        )
        b3.click(
            choose,
            [text, b3, embs, ys, calibrate_prompts],
            [text, embs, ys, calibrate_prompts]
        )

    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, text, embs, ys, calibrate_prompts])

    with gr.Row():
        html = gr.HTML('''
You will calibrate for several prompts and then roam.
<br><br>
Note that while the model is unlikely to produce NSFW text, it may still occur; please avoid rating NSFW content.
<br><br>
Thanks to @multimodalart for their contributions to the demo, especially the interface, and @maxbittker for feedback.
''')

demo.launch(share=True)
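
# To run this script locally (the package list is inferred from the imports
# above; versions are not pinned by the original):
#
#   pip install torch transformers gradio scikit-learn pandas numpy tqdm
#   python message_bottle.py
#
# The script expects ./twitter_prompts.csv next to it and, with the default
# device='cuda', a CUDA-capable GPU.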